diff options
| author | chrislu <chris.lu@gmail.com> | 2024-08-10 10:01:57 -0700 |
|---|---|---|
| committer | chrislu <chris.lu@gmail.com> | 2024-08-10 10:01:57 -0700 |
| commit | 7438648d1cfacd5ca570dd029d1bdb5fd271bd70 (patch) | |
| tree | cf12b49473be0373cb03d83470ddc75708454171 /weed/security/tls.go | |
| parent | 49893267e978cc3fda00dc991e00099742fb5a9d (diff) | |
| parent | 63c707f9c1b4dc469ec39c446563c324ce4ccb6f (diff) | |
| download | seaweedfs-7438648d1cfacd5ca570dd029d1bdb5fd271bd70.tar.xz seaweedfs-7438648d1cfacd5ca570dd029d1bdb5fd271bd70.zip | |
Merge branch 'master' into mq
Diffstat (limited to 'weed/security/tls.go')
| -rw-r--r-- | weed/security/tls.go | 100 |
1 files changed, 85 insertions, 15 deletions
diff --git a/weed/security/tls.go b/weed/security/tls.go index ae6510219..1a9dfacb5 100644 --- a/weed/security/tls.go +++ b/weed/security/tls.go @@ -4,16 +4,17 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/credentials/tls/certprovider/pemfile" - "google.golang.org/grpc/security/advancedtls" "os" + "slices" "strings" "time" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/util" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/tls/certprovider/pemfile" + "google.golang.org/grpc/security/advancedtls" ) const CredRefreshingInterval = time.Duration(5) * time.Hour @@ -54,7 +55,7 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption } // Start a server and create a client using advancedtls API with Provider. - options := &advancedtls.ServerOptions{ + options := &advancedtls.Options{ IdentityOptions: advancedtls.IdentityCertificateOptions{ IdentityProvider: serverIdentityProvider, }, @@ -62,7 +63,22 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption RootProvider: serverRootProvider, }, RequireClientCert: true, - VType: advancedtls.CertVerification, + VerificationType: advancedtls.CertVerification, + } + options.MinTLSVersion, err = TlsVersionByName(config.GetString("tls.min_version")) + if err != nil { + glog.Warningf("tls min version parse failed, %v", err) + return nil, nil + } + options.MaxTLSVersion, err = TlsVersionByName(config.GetString("tls.max_version")) + if err != nil { + glog.Warningf("tls max version parse failed, %v", err) + return nil, nil + } + options.CipherSuites, err = TlsCipherSuiteByNames(config.GetString("tls.cipher_suites")) + if err != nil { + glog.Warningf("tls cipher suite parse failed, %v", err) + return nil, nil } allowedCommonNames := config.GetString(component + ".allowed_commonNames") allowedWildcardDomain := config.GetString("grpc.allowed_wildcard_domain") @@ -75,10 +91,10 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption AllowedCommonNames: allowedCommonNamesMap, AllowedWildcardDomain: allowedWildcardDomain, } - options.VerifyPeer = auther.Authenticate + options.AdditionalPeerVerification = auther.Authenticate } else { - options.VerifyPeer = func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) { - return &advancedtls.VerificationResults{}, nil + options.AdditionalPeerVerification = func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { + return &advancedtls.PostHandshakeVerificationResults{}, nil } } ta, err := advancedtls.NewServerCreds(options) @@ -118,17 +134,17 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption { glog.Warningf("pemfile.NewProvider(%v) failed: %v", clientRootOptions, err) return grpc.WithTransportCredentials(insecure.NewCredentials()) } - options := &advancedtls.ClientOptions{ + options := &advancedtls.Options{ IdentityOptions: advancedtls.IdentityCertificateOptions{ IdentityProvider: clientProvider, }, - VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) { - return &advancedtls.VerificationResults{}, nil + AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { + return &advancedtls.PostHandshakeVerificationResults{}, nil }, RootOptions: advancedtls.RootCertificateOptions{ RootProvider: clientRootProvider, }, - VType: advancedtls.CertVerification, + VerificationType: advancedtls.CertVerification, } ta, err := advancedtls.NewClientCreds(options) if err != nil { @@ -155,14 +171,68 @@ func LoadClientTLSHTTP(clientCertFile string) *tls.Config { } } -func (a Authenticator) Authenticate(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) { +func (a Authenticator) Authenticate(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { if a.AllowedWildcardDomain != "" && strings.HasSuffix(params.Leaf.Subject.CommonName, a.AllowedWildcardDomain) { - return &advancedtls.VerificationResults{}, nil + return &advancedtls.PostHandshakeVerificationResults{}, nil } if _, ok := a.AllowedCommonNames[params.Leaf.Subject.CommonName]; ok { - return &advancedtls.VerificationResults{}, nil + return &advancedtls.PostHandshakeVerificationResults{}, nil } err := fmt.Errorf("Authenticate: invalid subject client common name: %s", params.Leaf.Subject.CommonName) glog.Error(err) return nil, err } + +func FixTlsConfig(viper *util.ViperProxy, config *tls.Config) error { + var err error + config.MinVersion, err = TlsVersionByName(viper.GetString("tls.min_version")) + if err != nil { + return err + } + config.MaxVersion, err = TlsVersionByName(viper.GetString("tls.max_version")) + if err != nil { + return err + } + config.CipherSuites, err = TlsCipherSuiteByNames(viper.GetString("tls.cipher_suites")) + return err +} + +func TlsVersionByName(name string) (uint16, error) { + switch name { + case "": + return 0, nil + case "SSLv3": + return tls.VersionSSL30, nil + case "TLS 1.0": + return tls.VersionTLS10, nil + case "TLS 1.1": + return tls.VersionTLS11, nil + case "TLS 1.2": + return tls.VersionTLS12, nil + case "TLS 1.3": + return tls.VersionTLS13, nil + default: + return 0, fmt.Errorf("invalid tls version %s", name) + } +} + +func TlsCipherSuiteByNames(cipherSuiteNames string) ([]uint16, error) { + cipherSuiteNames = strings.TrimSpace(cipherSuiteNames) + if cipherSuiteNames == "" { + return nil, nil + } + names := strings.Split(cipherSuiteNames, ",") + cipherSuites := tls.CipherSuites() + cipherIds := make([]uint16, 0, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + index := slices.IndexFunc(cipherSuites, func(suite *tls.CipherSuite) bool { + return name == suite.Name + }) + if index == -1 { + return nil, fmt.Errorf("invalid tls cipher suite name %s", name) + } + cipherIds = append(cipherIds, cipherSuites[index].ID) + } + return cipherIds, nil +} |
