aboutsummaryrefslogtreecommitdiff
path: root/test/s3/parquet/test_sse_s3_compatibility.py
blob: 534a6f814299a2a62939d180511e647764768e84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#!/usr/bin/env python3
"""
Test script for SSE-S3 compatibility with PyArrow native S3 filesystem.

This test specifically targets the SSE-S3 multipart upload bug where
SeaweedFS panics with "bad IV length" when reading multipart uploads
that were encrypted with bucket-default SSE-S3.

Requirements:
    - pyarrow>=10.0.0
    - boto3>=1.28.0

Environment Variables:
    S3_ENDPOINT_URL: S3 endpoint (default: localhost:8333)
    S3_ACCESS_KEY: S3 access key (default: some_access_key1)
    S3_SECRET_KEY: S3 secret key (default: some_secret_key1)
    BUCKET_NAME: S3 bucket name (default: test-parquet-bucket)

Usage:
    # Start SeaweedFS with SSE-S3 enabled
    make start-seaweedfs-ci ENABLE_SSE_S3=true
    
    # Run the test
    python3 test_sse_s3_compatibility.py
"""

import os
import secrets
import sys
import logging
from typing import Optional

import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.parquet as pq

try:
    import boto3
    from botocore.exceptions import ClientError
    HAS_BOTO3 = True
except ImportError:
    HAS_BOTO3 = False
    logging.exception("boto3 is required for this test")
    sys.exit(1)

from parquet_test_utils import create_sample_table

logging.basicConfig(level=logging.INFO, format="%(message)s")

# Connection settings. Every value can be overridden via the environment
# variable of the same name; the fallbacks match the documented defaults.
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "localhost:8333")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY", "some_access_key1")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY", "some_secret_key1")
BUCKET_NAME = os.getenv("BUCKET_NAME", "test-parquet-bucket")

# A random 16-hex-character id namespaces every object this run writes, so
# repeated or concurrent runs against the same bucket do not collide.
TEST_RUN_ID = secrets.token_hex(8)
TEST_DIR = f"sse-s3-tests/{TEST_RUN_ID}"

# Scenario name -> number of rows to write. The larger scenarios are meant to
# push the Parquet output past PyArrow's multipart upload threshold (assumed
# to be around 5 MB -- TODO confirm for the installed PyArrow version). The
# approximate byte sizes in the comments depend on create_sample_table()'s
# schema and are estimates only.
TEST_SIZES = dict(
    tiny=10,              # expected single part
    small=1_000,          # expected single part
    medium=50_000,        # expected single part (roughly 1.5 MB)
    large=200_000,        # expected multiple parts (roughly 6 MB)
    very_large=500_000,   # expected multiple parts (roughly 15 MB)
)


def _split_endpoint(url: str) -> tuple[str, str]:
    """Split an endpoint setting into ``(scheme, host[:port])``.

    Accepts ``http://host:port`` or ``https://host:port`` (the scheme is
    matched case-insensitively) as well as a bare ``host:port``, which
    defaults to ``http``. Surrounding whitespace and trailing slashes are
    removed so the host part is usable both as PyArrow's ``endpoint_override``
    and when re-joined into a boto3 ``endpoint_url``.
    """
    cleaned = url.strip()
    for scheme in ("https", "http"):
        prefix = f"{scheme}://"
        if cleaned[: len(prefix)].lower() == prefix:
            return scheme, cleaned[len(prefix):].rstrip("/")
    return "http", cleaned.rstrip("/")


def init_s3_filesystem() -> tuple[Optional[pafs.S3FileSystem], str, str]:
    """Initialize PyArrow's native S3 filesystem for ``S3_ENDPOINT_URL``.

    The filesystem is created with bucket creation and deletion allowed.

    Returns:
        ``(filesystem, scheme, endpoint)`` where ``scheme`` is ``"http"`` or
        ``"https"`` and ``endpoint`` is the ``host[:port]`` part of the
        configured URL. If construction fails, the error is logged and
        ``(None, "", "")`` is returned instead of raising.
    """
    try:
        logging.info("Initializing PyArrow S3FileSystem...")

        # Normalise the configured endpoint: case-insensitive scheme,
        # no stray whitespace or trailing slash.
        scheme, endpoint = _split_endpoint(S3_ENDPOINT_URL)

        s3 = pafs.S3FileSystem(
            access_key=S3_ACCESS_KEY,
            secret_key=S3_SECRET_KEY,
            endpoint_override=endpoint,
            scheme=scheme,
            allow_bucket_creation=True,
            allow_bucket_deletion=True,
        )

        logging.info("✓ PyArrow S3FileSystem initialized\n")
        return s3, scheme, endpoint
    except Exception:
        # Top-level setup boundary: report and let main() decide to abort.
        logging.exception("✗ Failed to initialize PyArrow S3FileSystem")
        return None, "", ""


def _client_error_code(err: "ClientError") -> str:
    """Return the S3 error code carried by a botocore ``ClientError`` ("" if absent)."""
    return err.response.get("Error", {}).get("Code", "")


def ensure_bucket_exists(scheme: str, endpoint: str) -> bool:
    """Ensure ``BUCKET_NAME`` exists, creating it with boto3 if it is missing.

    Args:
        scheme: ``"http"`` or ``"https"``, as returned by ``init_s3_filesystem``.
        endpoint: ``host[:port]`` of the S3 service.

    Returns:
        True if the bucket already existed or was created (including when a
        concurrent run created it first); False on any other error, which is
        logged.
    """
    # Codes meaning "no such bucket". HEAD responses have no body, so botocore
    # normally reports the bare HTTP status "404"; the symbolic codes are
    # accepted as well in case a server or proxy supplies them.
    missing_bucket_codes = {"404", "NoSuchBucket", "NotFound"}

    try:
        endpoint_url = f"{scheme}://{endpoint}"
        s3_client = boto3.client(
            's3',
            endpoint_url=endpoint_url,
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            region_name='us-east-1',
        )

        try:
            s3_client.head_bucket(Bucket=BUCKET_NAME)
            logging.info(f"✓ Bucket exists: {BUCKET_NAME}")
        except ClientError as e:
            if _client_error_code(e) not in missing_bucket_codes:
                logging.exception("✗ Failed to access bucket")
                return False

            logging.info(f"Creating bucket: {BUCKET_NAME}")
            try:
                s3_client.create_bucket(Bucket=BUCKET_NAME)
                logging.info(f"✓ Bucket created: {BUCKET_NAME}")
            except ClientError as create_err:
                # Another process may have created the bucket between our
                # HEAD and CREATE; that is success for our purposes. Any
                # other failure propagates to the outer handler below.
                if _client_error_code(create_err) != "BucketAlreadyOwnedByYou":
                    raise
                logging.info(f"✓ Bucket exists: {BUCKET_NAME}")

        # Note: SeaweedFS doesn't support GetBucketEncryption API
        # so we can't verify if SSE-S3 is enabled via API
        # We assume it's configured correctly in the s3.json config file
        logging.info("✓ Assuming SSE-S3 is configured in s3.json")
        return True

    except Exception:
        logging.exception("✗ Failed to check bucket")
        return False


def test_write_read_with_sse(
    s3: pafs.S3FileSystem,
    test_name: str,
    num_rows: int
) -> tuple[bool, str, int]:
    """Write a sample table to S3 as Parquet, read it back, and check the row count.

    Only the number of rows is compared; column values are not verified.

    Args:
        s3: PyArrow S3 filesystem to write to and read from.
        test_name: Scenario name; used as a path component under ``TEST_DIR``.
        num_rows: Number of rows passed to ``create_sample_table``.

    Returns:
        ``(success, message, rows_read)``. On failure ``rows_read`` is 0 and
        ``message`` describes the mismatch or the exception raised.
    """
    try:
        table = create_sample_table(num_rows)
        # write_dataset() treats this path as its base *directory* and writes
        # part files inside it; read_table() then reads the directory back as
        # a dataset. The ".parquet" suffix is only part of the directory name.
        filename = f"{BUCKET_NAME}/{TEST_DIR}/{test_name}/data.parquet"

        logging.info(f"  Writing {num_rows:,} rows...")
        pads.write_dataset(
            table,
            filename,
            filesystem=s3,
            format="parquet",
        )

        logging.info("  Reading back...")
        table_read = pq.read_table(filename, filesystem=s3)

        if table_read.num_rows != num_rows:
            return False, f"Row count mismatch: {table_read.num_rows} != {num_rows}", 0

        return True, "Success", table_read.num_rows

    except Exception as e:
        error_msg = f"{type(e).__name__}: {e!s}"
        logging.exception("  ✗ Failed")
        return False, error_msg, 0


# This is a helper driven by main(), not a pytest test: its parameters are not
# fixtures, so stop pytest from collecting it despite the ``test_`` prefix.
test_write_read_with_sse.__test__ = False


def main():
    """Run every SSE-S3 round-trip scenario and print a summary report.

    Prints a banner and the active configuration, builds the PyArrow S3
    filesystem, checks the test bucket, then runs one write/read round trip
    per entry in ``TEST_SIZES``.

    Returns:
        Exit status: 0 when every scenario passes; 1 when setup fails or any
        scenario fails.
    """
    rule = "=" * 80

    print(rule)
    print("SSE-S3 Compatibility Tests for PyArrow Native S3")
    print("Testing Multipart Upload Encryption")
    print(rule + "\n")

    print("Configuration:")
    settings = (
        ("S3 Endpoint", S3_ENDPOINT_URL),
        ("Bucket", BUCKET_NAME),
        ("Test Directory", TEST_DIR),
        ("PyArrow Version", pa.__version__),
    )
    for label, value in settings:
        print(f"  {label}: {value}")
    print()

    # Setup: abort early (exit status 1) if S3 or the bucket is unusable.
    s3, scheme, endpoint = init_s3_filesystem()
    if s3 is None:
        print("Cannot proceed without S3 connection")
        return 1

    if not ensure_bucket_exists(scheme, endpoint):
        print("\n⚠ WARNING: Failed to access or create the test bucket!")
        print("This test requires a reachable bucket with SSE-S3 enabled.")
        print("Please ensure SeaweedFS is running with: make start-seaweedfs-ci ENABLE_SSE_S3=true")
        return 1

    print()

    # Each outcome is (size_label, row_count, ok, detail, rows_read).
    outcomes = []
    for size_label, row_count in TEST_SIZES.items():
        print("\n" + rule)
        print(f"Testing {size_label} dataset ({row_count:,} rows)")
        print(rule)

        ok, detail, got = test_write_read_with_sse(s3, size_label, row_count)
        outcomes.append((size_label, row_count, ok, detail, got))

        print(f"  ✓ SUCCESS: Read {got:,} rows" if ok else f"  ✗ FAILED: {detail}")

    print("\n" + rule)
    print("SUMMARY")
    print(rule)

    total = len(outcomes)
    passed = len([outcome for outcome in outcomes if outcome[2]])
    print(f"\nTotal: {passed}/{total} tests passed\n")

    print(f"{'Size':<15} {'Rows':>10} {'Status':<10} {'Rows Read':>10} {'Message':<40}")
    print("-" * 90)
    for size_label, row_count, ok, detail, got in outcomes:
        status = "✓ PASS" if ok else "✗ FAIL"
        read_column = f"{got:,}" if ok else "N/A"
        print(f"{size_label:<15} {row_count:>10,} {status:<10} {read_column:>10} {detail[:40]}")

    all_passed = passed == total
    if all_passed:
        verdict = [
            "✓ ALL TESTS PASSED WITH SSE-S3!",
            "\nThis means:",
            "  - SSE-S3 encryption is working correctly",
            "  - PyArrow native S3 filesystem is compatible",
            "  - Multipart uploads are handled properly",
        ]
    else:
        verdict = [
            f"✗ {total - passed} test(s) failed",
            "\nPossible issues:",
            "  - SSE-S3 multipart upload bug with empty IV",
            "  - Encryption/decryption mismatch",
            "  - File corruption during upload",
        ]

    print("\n" + rule)
    for line in verdict:
        print(line)
    print(rule + "\n")

    return 0 if all_passed else 1


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())