aboutsummaryrefslogtreecommitdiff
path: root/weed/security/tls.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/security/tls.go')
-rw-r--r--weed/security/tls.go100
1 files changed, 85 insertions, 15 deletions
diff --git a/weed/security/tls.go b/weed/security/tls.go
index ae6510219..1a9dfacb5 100644
--- a/weed/security/tls.go
+++ b/weed/security/tls.go
@@ -4,16 +4,17 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
- "google.golang.org/grpc/credentials/insecure"
- "google.golang.org/grpc/credentials/tls/certprovider/pemfile"
- "google.golang.org/grpc/security/advancedtls"
"os"
+ "slices"
"strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/util"
"google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/credentials/tls/certprovider/pemfile"
+ "google.golang.org/grpc/security/advancedtls"
)
const CredRefreshingInterval = time.Duration(5) * time.Hour
@@ -54,7 +55,7 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption
}
// Start a server and create a client using advancedtls API with Provider.
- options := &advancedtls.ServerOptions{
+ options := &advancedtls.Options{
IdentityOptions: advancedtls.IdentityCertificateOptions{
IdentityProvider: serverIdentityProvider,
},
@@ -62,7 +63,22 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption
RootProvider: serverRootProvider,
},
RequireClientCert: true,
- VType: advancedtls.CertVerification,
+ VerificationType: advancedtls.CertVerification,
+ }
+ options.MinTLSVersion, err = TlsVersionByName(config.GetString("tls.min_version"))
+ if err != nil {
+ glog.Warningf("tls min version parse failed, %v", err)
+ return nil, nil
+ }
+ options.MaxTLSVersion, err = TlsVersionByName(config.GetString("tls.max_version"))
+ if err != nil {
+ glog.Warningf("tls max version parse failed, %v", err)
+ return nil, nil
+ }
+ options.CipherSuites, err = TlsCipherSuiteByNames(config.GetString("tls.cipher_suites"))
+ if err != nil {
+ glog.Warningf("tls cipher suite parse failed, %v", err)
+ return nil, nil
}
allowedCommonNames := config.GetString(component + ".allowed_commonNames")
allowedWildcardDomain := config.GetString("grpc.allowed_wildcard_domain")
@@ -75,10 +91,10 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption
AllowedCommonNames: allowedCommonNamesMap,
AllowedWildcardDomain: allowedWildcardDomain,
}
- options.VerifyPeer = auther.Authenticate
+ options.AdditionalPeerVerification = auther.Authenticate
} else {
- options.VerifyPeer = func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) {
- return &advancedtls.VerificationResults{}, nil
+ options.AdditionalPeerVerification = func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
+ return &advancedtls.PostHandshakeVerificationResults{}, nil
}
}
ta, err := advancedtls.NewServerCreds(options)
@@ -118,17 +134,17 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption {
glog.Warningf("pemfile.NewProvider(%v) failed: %v", clientRootOptions, err)
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
- options := &advancedtls.ClientOptions{
+ options := &advancedtls.Options{
IdentityOptions: advancedtls.IdentityCertificateOptions{
IdentityProvider: clientProvider,
},
- VerifyPeer: func(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) {
- return &advancedtls.VerificationResults{}, nil
+ AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
+ return &advancedtls.PostHandshakeVerificationResults{}, nil
},
RootOptions: advancedtls.RootCertificateOptions{
RootProvider: clientRootProvider,
},
- VType: advancedtls.CertVerification,
+ VerificationType: advancedtls.CertVerification,
}
ta, err := advancedtls.NewClientCreds(options)
if err != nil {
@@ -155,14 +171,68 @@ func LoadClientTLSHTTP(clientCertFile string) *tls.Config {
}
}
-func (a Authenticator) Authenticate(params *advancedtls.VerificationFuncParams) (*advancedtls.VerificationResults, error) {
+func (a Authenticator) Authenticate(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
if a.AllowedWildcardDomain != "" && strings.HasSuffix(params.Leaf.Subject.CommonName, a.AllowedWildcardDomain) {
- return &advancedtls.VerificationResults{}, nil
+ return &advancedtls.PostHandshakeVerificationResults{}, nil
}
if _, ok := a.AllowedCommonNames[params.Leaf.Subject.CommonName]; ok {
- return &advancedtls.VerificationResults{}, nil
+ return &advancedtls.PostHandshakeVerificationResults{}, nil
}
err := fmt.Errorf("Authenticate: invalid subject client common name: %s", params.Leaf.Subject.CommonName)
glog.Error(err)
return nil, err
}
+
+func FixTlsConfig(viper *util.ViperProxy, config *tls.Config) error {
+ var err error
+ config.MinVersion, err = TlsVersionByName(viper.GetString("tls.min_version"))
+ if err != nil {
+ return err
+ }
+ config.MaxVersion, err = TlsVersionByName(viper.GetString("tls.max_version"))
+ if err != nil {
+ return err
+ }
+ config.CipherSuites, err = TlsCipherSuiteByNames(viper.GetString("tls.cipher_suites"))
+ return err
+}
+
+func TlsVersionByName(name string) (uint16, error) {
+ switch name {
+ case "":
+ return 0, nil
+ case "SSLv3":
+ return tls.VersionSSL30, nil
+ case "TLS 1.0":
+ return tls.VersionTLS10, nil
+ case "TLS 1.1":
+ return tls.VersionTLS11, nil
+ case "TLS 1.2":
+ return tls.VersionTLS12, nil
+ case "TLS 1.3":
+ return tls.VersionTLS13, nil
+ default:
+ return 0, fmt.Errorf("invalid tls version %s", name)
+ }
+}
+
+func TlsCipherSuiteByNames(cipherSuiteNames string) ([]uint16, error) {
+ cipherSuiteNames = strings.TrimSpace(cipherSuiteNames)
+ if cipherSuiteNames == "" {
+ return nil, nil
+ }
+ names := strings.Split(cipherSuiteNames, ",")
+ cipherSuites := tls.CipherSuites()
+ cipherIds := make([]uint16, 0, len(names))
+ for _, name := range names {
+ name = strings.TrimSpace(name)
+ index := slices.IndexFunc(cipherSuites, func(suite *tls.CipherSuite) bool {
+ return name == suite.Name
+ })
+ if index == -1 {
+ return nil, fmt.Errorf("invalid tls cipher suite name %s", name)
+ }
+ cipherIds = append(cipherIds, cipherSuites[index].ID)
+ }
+ return cipherIds, nil
+}