Diffstat (limited to 'test/s3/parquet/test_sse_s3_compatibility.py')
| -rwxr-xr-x | test/s3/parquet/test_sse_s3_compatibility.py | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/test/s3/parquet/test_sse_s3_compatibility.py b/test/s3/parquet/test_sse_s3_compatibility.py
new file mode 100755
index 000000000..534a6f814
--- /dev/null
+++ b/test/s3/parquet/test_sse_s3_compatibility.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+"""
+Test script for SSE-S3 compatibility with PyArrow native S3 filesystem.
+
+This test specifically targets the SSE-S3 multipart upload bug where
+SeaweedFS panics with "bad IV length" when reading multipart uploads
+that were encrypted with bucket-default SSE-S3.
+
+Requirements:
+    - pyarrow>=10.0.0
+    - boto3>=1.28.0
+
+Environment Variables:
+    S3_ENDPOINT_URL: S3 endpoint (default: localhost:8333)
+    S3_ACCESS_KEY: S3 access key (default: some_access_key1)
+    S3_SECRET_KEY: S3 secret key (default: some_secret_key1)
+    BUCKET_NAME: S3 bucket name (default: test-parquet-bucket)
+
+Usage:
+    # Start SeaweedFS with SSE-S3 enabled
+    make start-seaweedfs-ci ENABLE_SSE_S3=true
+
+    # Run the test
+    python3 test_sse_s3_compatibility.py
+"""
+
+import os
+import secrets
+import sys
+import logging
+from typing import Optional
+
+import pyarrow as pa
+import pyarrow.dataset as pads
+import pyarrow.fs as pafs
+import pyarrow.parquet as pq
+
+try:
+    import boto3
+    from botocore.exceptions import ClientError
+    HAS_BOTO3 = True
+except ImportError:
+    HAS_BOTO3 = False
+    logging.exception("boto3 is required for this test")
+    sys.exit(1)
+
+from parquet_test_utils import create_sample_table
+
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+# Configuration
+S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "localhost:8333")
+S3_ACCESS_KEY = os.environ.get("S3_ACCESS_KEY", "some_access_key1")
+S3_SECRET_KEY = os.environ.get("S3_SECRET_KEY", "some_secret_key1")
+BUCKET_NAME = os.getenv("BUCKET_NAME", "test-parquet-bucket")
+
+TEST_RUN_ID = secrets.token_hex(8)
+TEST_DIR = f"sse-s3-tests/{TEST_RUN_ID}"
+
+# Test sizes designed to trigger multipart uploads
+# PyArrow typically uses 5MB chunks, so these sizes should trigger multipart
+TEST_SIZES = {
+    "tiny": 10,             # Single part
+    "small": 1_000,         # Single part
+    "medium": 50_000,       # Single part (~1.5MB)
+    "large": 200_000,       # Multiple parts (~6MB)
+    "very_large": 500_000,  # Multiple parts (~15MB)
+}
+
+
+def init_s3_filesystem() -> tuple[Optional[pafs.S3FileSystem], str, str]:
+    """Initialize PyArrow's native S3 filesystem."""
+    try:
+        logging.info("Initializing PyArrow S3FileSystem...")
+
+        # Determine scheme from endpoint
+        if S3_ENDPOINT_URL.startswith("http://"):
+            scheme = "http"
+            endpoint = S3_ENDPOINT_URL[7:]
+        elif S3_ENDPOINT_URL.startswith("https://"):
+            scheme = "https"
+            endpoint = S3_ENDPOINT_URL[8:]
+        else:
+            scheme = "http"
+            endpoint = S3_ENDPOINT_URL
+
+        s3 = pafs.S3FileSystem(
+            access_key=S3_ACCESS_KEY,
+            secret_key=S3_SECRET_KEY,
+            endpoint_override=endpoint,
+            scheme=scheme,
+            allow_bucket_creation=True,
+            allow_bucket_deletion=True,
+        )
+
+        logging.info("✓ PyArrow S3FileSystem initialized\n")
+        return s3, scheme, endpoint
+    except Exception:
+        logging.exception("✗ Failed to initialize PyArrow S3FileSystem")
+        return None, "", ""
+
+
+def ensure_bucket_exists(scheme: str, endpoint: str) -> bool:
+    """Ensure the test bucket exists using boto3."""
+    try:
+        endpoint_url = f"{scheme}://{endpoint}"
+        s3_client = boto3.client(
+            's3',
+            endpoint_url=endpoint_url,
+            aws_access_key_id=S3_ACCESS_KEY,
+            aws_secret_access_key=S3_SECRET_KEY,
+            region_name='us-east-1',
+        )
+
+        try:
+            s3_client.head_bucket(Bucket=BUCKET_NAME)
+            logging.info(f"✓ Bucket exists: {BUCKET_NAME}")
+        except ClientError as e:
+            error_code = e.response['Error']['Code']
+            if error_code == '404':
+                logging.info(f"Creating bucket: {BUCKET_NAME}")
+                s3_client.create_bucket(Bucket=BUCKET_NAME)
+                logging.info(f"✓ Bucket created: {BUCKET_NAME}")
+            else:
+                logging.exception("✗ Failed to access bucket")
+                return False
+
+        # Note: SeaweedFS doesn't support GetBucketEncryption API
+        # so we can't verify if SSE-S3 is enabled via API
+        # We assume it's configured correctly in the s3.json config file
+        logging.info("✓ Assuming SSE-S3 is configured in s3.json")
+        return True
+
+    except Exception:
+        logging.exception("✗ Failed to check bucket")
+        return False
+
+
+def test_write_read_with_sse(
+    s3: pafs.S3FileSystem,
+    test_name: str,
+    num_rows: int
+) -> tuple[bool, str, int]:
+    """Test writing and reading with SSE-S3 encryption."""
+    try:
+        table = create_sample_table(num_rows)
+        filename = f"{BUCKET_NAME}/{TEST_DIR}/{test_name}/data.parquet"
+
+        logging.info(f"  Writing {num_rows:,} rows...")
+        pads.write_dataset(
+            table,
+            filename,
+            filesystem=s3,
+            format="parquet",
+        )
+
+        logging.info("  Reading back...")
+        table_read = pq.read_table(filename, filesystem=s3)
+
+        if table_read.num_rows != num_rows:
+            return False, f"Row count mismatch: {table_read.num_rows} != {num_rows}", 0
+
+        return True, "Success", table_read.num_rows
+
+    except Exception as e:
+        error_msg = f"{type(e).__name__}: {e!s}"
+        logging.exception("  ✗ Failed")
+        return False, error_msg, 0
+
+
+def main():
+    """Run SSE-S3 compatibility tests."""
+    print("=" * 80)
+    print("SSE-S3 Compatibility Tests for PyArrow Native S3")
+    print("Testing Multipart Upload Encryption")
+    print("=" * 80 + "\n")
+
+    print("Configuration:")
+    print(f"  S3 Endpoint: {S3_ENDPOINT_URL}")
+    print(f"  Bucket: {BUCKET_NAME}")
+    print(f"  Test Directory: {TEST_DIR}")
+    print(f"  PyArrow Version: {pa.__version__}")
+    print()
+
+    # Initialize
+    s3, scheme, endpoint = init_s3_filesystem()
+    if s3 is None:
+        print("Cannot proceed without S3 connection")
+        return 1
+
+    # Check bucket and SSE-S3
+    if not ensure_bucket_exists(scheme, endpoint):
+        print("\n⚠ WARNING: Failed to access or create the test bucket!")
+        print("This test requires a reachable bucket with SSE-S3 enabled.")
+        print("Please ensure SeaweedFS is running with: make start-seaweedfs-ci ENABLE_SSE_S3=true")
+        return 1
+
+    print()
+    results = []
+
+    # Test all sizes
+    for size_name, num_rows in TEST_SIZES.items():
+        print(f"\n{'='*80}")
+        print(f"Testing {size_name} dataset ({num_rows:,} rows)")
+        print(f"{'='*80}")
+
+        success, message, rows_read = test_write_read_with_sse(
+            s3, size_name, num_rows
+        )
+        results.append((size_name, num_rows, success, message, rows_read))
+
+        if success:
+            print(f"  ✓ SUCCESS: Read {rows_read:,} rows")
+        else:
+            print(f"  ✗ FAILED: {message}")
+
+    # Summary
+    print("\n" + "=" * 80)
+    print("SUMMARY")
+    print("=" * 80)
+
+    passed = sum(1 for _, _, success, _, _ in results if success)
+    total = len(results)
+    print(f"\nTotal: {passed}/{total} tests passed\n")
+
+    print(f"{'Size':<15} {'Rows':>10} {'Status':<10} {'Rows Read':>10} {'Message':<40}")
+    print("-" * 90)
+    for size_name, num_rows, success, message, rows_read in results:
+        status = "✓ PASS" if success else "✗ FAIL"
+        rows_str = f"{rows_read:,}" if success else "N/A"
+        print(f"{size_name:<15} {num_rows:>10,} {status:<10} {rows_str:>10} {message[:40]}")
+
+    print("\n" + "=" * 80)
+    if passed == total:
+        print("✓ ALL TESTS PASSED WITH SSE-S3!")
+        print("\nThis means:")
+        print("  - SSE-S3 encryption is working correctly")
+ print(" - PyArrow native S3 filesystem is compatible") + print(" - Multipart uploads are handled properly") + else: + print(f"✗ {total - passed} test(s) failed") + print("\nPossible issues:") + print(" - SSE-S3 multipart upload bug with empty IV") + print(" - Encryption/decryption mismatch") + print(" - File corruption during upload") + + print("=" * 80 + "\n") + + return 0 if passed == total else 1 + + +if __name__ == "__main__": + sys.exit(main()) + |

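The multipart failure mode can also be reproduced in isolation by driving the patch's helpers directly. A usage sketch, assuming the file above and parquet_test_utils are on the import path and the same environment variables are set (the ~6MB "large" case is the smallest size expected to exceed PyArrow's 5MB part size):

from test_sse_s3_compatibility import init_s3_filesystem, test_write_read_with_sse

s3, scheme, endpoint = init_s3_filesystem()
if s3 is not None:
    # Roughly 6MB of Parquet data, enough to force a multipart upload
    ok, message, rows_read = test_write_read_with_sse(s3, "large", 200_000)
    print(ok, message, rows_read)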