aboutsummaryrefslogtreecommitdiff
path: root/weed/s3api/s3_sse_c_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/s3api/s3_sse_c_test.go')
-rw-r--r--weed/s3api/s3_sse_c_test.go412
1 files changed, 412 insertions, 0 deletions
diff --git a/weed/s3api/s3_sse_c_test.go b/weed/s3api/s3_sse_c_test.go
new file mode 100644
index 000000000..51c536445
--- /dev/null
+++ b/weed/s3api/s3_sse_c_test.go
@@ -0,0 +1,412 @@
+package s3api
+
+import (
+ "bytes"
+ "crypto/md5"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
+)
+
+func base64MD5(b []byte) string {
+ s := md5.Sum(b)
+ return base64.StdEncoding.EncodeToString(s[:])
+}
+
+func TestSSECHeaderValidation(t *testing.T) {
+ // Test valid SSE-C headers
+ req := &http.Request{Header: make(http.Header)}
+
+ key := make([]byte, 32) // 256-bit key
+ for i := range key {
+ key[i] = byte(i)
+ }
+
+ keyBase64 := base64.StdEncoding.EncodeToString(key)
+ md5sum := md5.Sum(key)
+ keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:])
+
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64)
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5)
+
+ // Test validation
+ err := ValidateSSECHeaders(req)
+ if err != nil {
+ t.Errorf("Expected valid headers, got error: %v", err)
+ }
+
+ // Test parsing
+ customerKey, err := ParseSSECHeaders(req)
+ if err != nil {
+ t.Errorf("Expected successful parsing, got error: %v", err)
+ }
+
+ if customerKey == nil {
+ t.Error("Expected customer key, got nil")
+ }
+
+ if customerKey.Algorithm != "AES256" {
+ t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
+ }
+
+ if !bytes.Equal(customerKey.Key, key) {
+ t.Error("Key doesn't match original")
+ }
+
+ if customerKey.KeyMD5 != keyMD5 {
+ t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5)
+ }
+}
+
+func TestSSECCopySourceHeaders(t *testing.T) {
+ // Test valid SSE-C copy source headers
+ req := &http.Request{Header: make(http.Header)}
+
+ key := make([]byte, 32) // 256-bit key
+ for i := range key {
+ key[i] = byte(i) + 1 // Different from regular test
+ }
+
+ keyBase64 := base64.StdEncoding.EncodeToString(key)
+ md5sum2 := md5.Sum(key)
+ keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:])
+
+ req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256")
+ req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64)
+ req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5)
+
+ // Test parsing copy source headers
+ customerKey, err := ParseSSECCopySourceHeaders(req)
+ if err != nil {
+ t.Errorf("Expected successful copy source parsing, got error: %v", err)
+ }
+
+ if customerKey == nil {
+ t.Error("Expected customer key from copy source headers, got nil")
+ }
+
+ if customerKey.Algorithm != "AES256" {
+ t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
+ }
+
+ if !bytes.Equal(customerKey.Key, key) {
+ t.Error("Copy source key doesn't match original")
+ }
+
+ // Test that regular headers don't interfere with copy source headers
+ regularKey, err := ParseSSECHeaders(req)
+ if err != nil {
+ t.Errorf("Regular header parsing should not fail: %v", err)
+ }
+
+ if regularKey != nil {
+ t.Error("Expected nil for regular headers when only copy source headers are present")
+ }
+}
+
+func TestSSECHeaderValidationErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ algorithm string
+ key string
+ keyMD5 string
+ wantErr error
+ }{
+ {
+ name: "invalid algorithm",
+ algorithm: "AES128",
+ key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
+ keyMD5: base64MD5(make([]byte, 32)),
+ wantErr: ErrInvalidEncryptionAlgorithm,
+ },
+ {
+ name: "invalid key length",
+ algorithm: "AES256",
+ key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
+ keyMD5: base64MD5(make([]byte, 16)),
+ wantErr: ErrInvalidEncryptionKey,
+ },
+ {
+ name: "mismatched MD5",
+ algorithm: "AES256",
+ key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
+ keyMD5: "wrong==md5",
+ wantErr: ErrSSECustomerKeyMD5Mismatch,
+ },
+ {
+ name: "incomplete headers",
+ algorithm: "AES256",
+ key: "",
+ keyMD5: "",
+ wantErr: ErrInvalidRequest,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := &http.Request{Header: make(http.Header)}
+
+ if tt.algorithm != "" {
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm)
+ }
+ if tt.key != "" {
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key)
+ }
+ if tt.keyMD5 != "" {
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5)
+ }
+
+ err := ValidateSSECHeaders(req)
+ if err != tt.wantErr {
+ t.Errorf("Expected error %v, got %v", tt.wantErr, err)
+ }
+ })
+ }
+}
+
+func TestSSECEncryptionDecryption(t *testing.T) {
+ // Create customer key
+ key := make([]byte, 32)
+ for i := range key {
+ key[i] = byte(i)
+ }
+
+ md5sumKey := md5.Sum(key)
+ customerKey := &SSECustomerKey{
+ Algorithm: "AES256",
+ Key: key,
+ KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey[:]),
+ }
+
+ // Test data
+ testData := []byte("Hello, World! This is a test of SSE-C encryption.")
+
+ // Create encrypted reader
+ dataReader := bytes.NewReader(testData)
+ encryptedReader, err := CreateSSECEncryptedReader(dataReader, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create encrypted reader: %v", err)
+ }
+
+ // Read encrypted data
+ encryptedData, err := io.ReadAll(encryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read encrypted data: %v", err)
+ }
+
+ // Verify data is actually encrypted (different from original)
+ if bytes.Equal(encryptedData[16:], testData) { // Skip IV
+ t.Error("Data doesn't appear to be encrypted")
+ }
+
+ // Create decrypted reader
+ encryptedReader2 := bytes.NewReader(encryptedData)
+ decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create decrypted reader: %v", err)
+ }
+
+ // Read decrypted data
+ decryptedData, err := io.ReadAll(decryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read decrypted data: %v", err)
+ }
+
+ // Verify decrypted data matches original
+ if !bytes.Equal(decryptedData, testData) {
+ t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
+ }
+}
+
+func TestSSECIsSSECRequest(t *testing.T) {
+ // Test with SSE-C headers
+ req := &http.Request{Header: make(http.Header)}
+ req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
+
+ if !IsSSECRequest(req) {
+ t.Error("Expected IsSSECRequest to return true when SSE-C headers are present")
+ }
+
+ // Test without SSE-C headers
+ req2 := &http.Request{Header: make(http.Header)}
+ if IsSSECRequest(req2) {
+ t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present")
+ }
+}
+
+// Test encryption with different data sizes (similar to s3tests)
+func TestSSECEncryptionVariousSizes(t *testing.T) {
+ sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB
+
+ for _, size := range sizes {
+ t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
+ // Create customer key
+ key := make([]byte, 32)
+ for i := range key {
+ key[i] = byte(i + size) // Make key unique per test
+ }
+
+ md5sumDyn := md5.Sum(key)
+ customerKey := &SSECustomerKey{
+ Algorithm: "AES256",
+ Key: key,
+ KeyMD5: base64.StdEncoding.EncodeToString(md5sumDyn[:]),
+ }
+
+ // Create test data of specified size
+ testData := make([]byte, size)
+ for i := range testData {
+ testData[i] = byte('A' + (i % 26)) // Pattern of A-Z
+ }
+
+ // Encrypt
+ dataReader := bytes.NewReader(testData)
+ encryptedReader, err := CreateSSECEncryptedReader(dataReader, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create encrypted reader: %v", err)
+ }
+
+ encryptedData, err := io.ReadAll(encryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read encrypted data: %v", err)
+ }
+
+ // Verify IV is present and data is encrypted
+ if len(encryptedData) < AESBlockSize {
+ t.Fatalf("Encrypted data too short, missing IV")
+ }
+
+ if len(encryptedData) != size+AESBlockSize {
+ t.Errorf("Expected encrypted data length %d, got %d", size+AESBlockSize, len(encryptedData))
+ }
+
+ // Decrypt
+ encryptedReader2 := bytes.NewReader(encryptedData)
+ decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create decrypted reader: %v", err)
+ }
+
+ decryptedData, err := io.ReadAll(decryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read decrypted data: %v", err)
+ }
+
+ // Verify decrypted data matches original
+ if !bytes.Equal(decryptedData, testData) {
+ t.Errorf("Decrypted data doesn't match original for size %d", size)
+ }
+ })
+ }
+}
+
+func TestSSECEncryptionWithNilKey(t *testing.T) {
+ testData := []byte("test data")
+ dataReader := bytes.NewReader(testData)
+
+ // Test encryption with nil key (should pass through)
+ encryptedReader, err := CreateSSECEncryptedReader(dataReader, nil)
+ if err != nil {
+ t.Fatalf("Failed to create encrypted reader with nil key: %v", err)
+ }
+
+ result, err := io.ReadAll(encryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read from pass-through reader: %v", err)
+ }
+
+ if !bytes.Equal(result, testData) {
+ t.Error("Data should pass through unchanged when key is nil")
+ }
+
+ // Test decryption with nil key (should pass through)
+ dataReader2 := bytes.NewReader(testData)
+ decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil)
+ if err != nil {
+ t.Fatalf("Failed to create decrypted reader with nil key: %v", err)
+ }
+
+ result2, err := io.ReadAll(decryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read from pass-through reader: %v", err)
+ }
+
+ if !bytes.Equal(result2, testData) {
+ t.Error("Data should pass through unchanged when key is nil")
+ }
+}
+
+// TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers
+// could corrupt the data stream when reading in chunks smaller than the IV size
+func TestSSECEncryptionSmallBuffers(t *testing.T) {
+ testData := []byte("This is a test message for small buffer reads")
+
+ // Create customer key
+ key := make([]byte, 32)
+ for i := range key {
+ key[i] = byte(i)
+ }
+
+ md5sumKey3 := md5.Sum(key)
+ customerKey := &SSECustomerKey{
+ Algorithm: "AES256",
+ Key: key,
+ KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey3[:]),
+ }
+
+ // Create encrypted reader
+ dataReader := bytes.NewReader(testData)
+ encryptedReader, err := CreateSSECEncryptedReader(dataReader, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create encrypted reader: %v", err)
+ }
+
+ // Read with very small buffers (smaller than IV size of 16 bytes)
+ var encryptedData []byte
+ smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV
+
+ for {
+ n, err := encryptedReader.Read(smallBuffer)
+ if n > 0 {
+ encryptedData = append(encryptedData, smallBuffer[:n]...)
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("Error reading encrypted data: %v", err)
+ }
+ }
+
+ // Verify the encrypted data starts with 16-byte IV
+ if len(encryptedData) < 16 {
+ t.Fatalf("Encrypted data too short, expected at least 16 bytes for IV, got %d", len(encryptedData))
+ }
+
+ // Expected total size: 16 bytes (IV) + len(testData)
+ expectedSize := 16 + len(testData)
+ if len(encryptedData) != expectedSize {
+ t.Errorf("Expected encrypted data size %d, got %d", expectedSize, len(encryptedData))
+ }
+
+ // Decrypt and verify
+ encryptedReader2 := bytes.NewReader(encryptedData)
+ decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey)
+ if err != nil {
+ t.Fatalf("Failed to create decrypted reader: %v", err)
+ }
+
+ decryptedData, err := io.ReadAll(decryptedReader)
+ if err != nil {
+ t.Fatalf("Failed to read decrypted data: %v", err)
+ }
+
+ if !bytes.Equal(decryptedData, testData) {
+ t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
+ }
+}