diff options
| author | vadimartynov <166398828+vadimartynov@users.noreply.github.com> | 2024-07-17 09:14:09 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-16 23:14:09 -0700 |
| commit | 86d92a42b4861d4bb05c58fea9db84d960995545 (patch) | |
| tree | b3b8cefc07fe3d10c0dc0c69120a9019584bd60a /weed/util/http/client | |
| parent | c6dec11ea556b8be648f372dfa5cbd074c9f631b (diff) | |
| download | seaweedfs-86d92a42b4861d4bb05c58fea9db84d960995545.tar.xz seaweedfs-86d92a42b4861d4bb05c58fea9db84d960995545.zip | |
Added tls for http clients (#5766)
* Added global http client
* Added Do func for global http client
* Changed the code to use the global http client
* Fix http client in volume uploader
* Fixed pkg name
* Fixed http util funcs
* Fixed http client for bench_filer_upload
* Fixed http client for stress_filer_upload
* Fixed http client for filer_server_handlers_proxy
* Fixed http client for command_fs_merge_volumes
* Fixed http client for command_fs_merge_volumes and command_volume_fsck
* Fixed http client for s3api_server
* Added init global client for main funcs
* Rename global_client to client
* Changed:
- fixed NewHttpClient;
- added CheckIsHttpsClientEnabled func
- updated security.toml in scaffold
* Reduce the visibility of some functions in the util/http/client pkg
* Added the loadSecurityConfig function
* Use util.LoadSecurityConfiguration() in NewHttpClient func
Diffstat (limited to 'weed/util/http/client')
| -rw-r--r-- | weed/util/http/client/http_client.go | 201 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_interface.go | 16 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_name.go | 14 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_name_string.go | 23 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_opt.go | 18 |
5 files changed, 272 insertions, 0 deletions
diff --git a/weed/util/http/client/http_client.go b/weed/util/http/client/http_client.go new file mode 100644 index 000000000..d1d2f5c56 --- /dev/null +++ b/weed/util/http/client/http_client.go @@ -0,0 +1,201 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + util "github.com/seaweedfs/seaweedfs/weed/util" + "github.com/spf13/viper" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" +) + +var ( + loadSecurityConfigOnce sync.Once +) + +type HTTPClient struct { + Client *http.Client + Transport *http.Transport + expectHttpsScheme bool +} + +func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) { + req.URL.Scheme = httpClient.GetHttpScheme() + return httpClient.Client.Do(req) +} + +func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Get(url) +} + +func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Post(url, contentType, body) +} + +func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.PostForm(url, data) +} + +func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Head(url) +} +func (httpClient *HTTPClient) CloseIdleConnections() { + httpClient.Client.CloseIdleConnections() +} + +func (httpClient *HTTPClient) GetClientTransport() *http.Transport { + return httpClient.Transport +} + +func (httpClient *HTTPClient) GetHttpScheme() string { + if httpClient.expectHttpsScheme { + return "https" + } + return "http" +} + +func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) { + expectedScheme := httpClient.GetHttpScheme() + + if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) { + return expectedScheme + "://" + rawURL, nil + } + + parsedURL, err := url.Parse(rawURL) + if err != nil { + return "", err + } + + if expectedScheme != parsedURL.Scheme { + parsedURL.Scheme = expectedScheme + } + return parsedURL.String(), nil +} + +func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) { + httpClient := HTTPClient{} + httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName) + var tlsConfig *tls.Config = nil + + if httpClient.expectHttpsScheme { + clientCertPair, err := getClientCertPair(clientName) + if err != nil { + return nil, err + } + + clientCaCert, clientCaCertName, err := getClientCaCert(clientName) + if err != nil { + return nil, err + } + + if clientCertPair != nil || len(clientCaCert) != 0 { + caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName) + if err != nil { + return nil, err + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{}, + RootCAs: caCertPool, + InsecureSkipVerify: false, + } + + if clientCertPair != nil { + tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair) + } + } + } + + httpClient.Transport = &http.Transport{ + MaxIdleConns: 1024, + MaxIdleConnsPerHost: 1024, + TLSClientConfig: tlsConfig, + } + httpClient.Client = &http.Client{ + Transport: httpClient.Transport, + } + + for _, opt := range opts { + opt(&httpClient) + } + return &httpClient, nil +} + +func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string { + util.LoadSecurityConfiguration() + return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName)) +} + +func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool { + util.LoadSecurityConfiguration() + return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName)) +} + +func checkIsHttpsClientEnabled(clientName ClientName) bool { + return getBoolOptionFromSecurityConfiguration(clientName, "enabled") +} + +func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) { + if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" { + fileContent, err := os.ReadFile(fileName) + if err != nil { + return nil, fileName, err + } + return fileContent, fileName, err + } + return nil, "", nil +} + +func getClientCertPair(clientName ClientName) (*tls.Certificate, error) { + certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert") + keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key") + if certFileName == "" && keyFileName == "" { + return nil, nil + } + if certFileName != "" && keyFileName != "" { + clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName) + if err != nil { + return nil, fmt.Errorf("error loading client certificate and key: %s", err) + } + return &clientCert, nil + } + return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName) +} + +func getClientCaCert(clientName ClientName) ([]byte, string, error) { + return getFileContentFromSecurityConfiguration(clientName, "ca") +} + +func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + if len(certContent) == 0 { + return certPool, nil + } + + ok := certPool.AppendCertsFromPEM(certContent) + if !ok { + return nil, fmt.Errorf("error processing certificate in %s", fileName) + } + return certPool, nil +} diff --git a/weed/util/http/client/http_client_interface.go b/weed/util/http/client/http_client_interface.go new file mode 100644 index 000000000..7a2d43360 --- /dev/null +++ b/weed/util/http/client/http_client_interface.go @@ -0,0 +1,16 @@ +package client + +import ( + "io" + "net/http" + "net/url" +) + +type HTTPClientInterface interface { + Do(req *http.Request) (*http.Response, error) + Get(url string) (resp *http.Response, err error) + Post(url, contentType string, body io.Reader) (resp *http.Response, err error) + PostForm(url string, data url.Values) (resp *http.Response, err error) + Head(url string) (resp *http.Response, err error) + CloseIdleConnections() +} diff --git a/weed/util/http/client/http_client_name.go b/weed/util/http/client/http_client_name.go new file mode 100644 index 000000000..aedaebbc6 --- /dev/null +++ b/weed/util/http/client/http_client_name.go @@ -0,0 +1,14 @@ +package client + +import "strings" + +type ClientName int + +//go:generate stringer -type=ClientName -output=http_client_name_string.go +const ( + Client ClientName = iota +) + +func (name *ClientName) LowerCaseString() string { + return strings.ToLower(name.String()) +} diff --git a/weed/util/http/client/http_client_name_string.go b/weed/util/http/client/http_client_name_string.go new file mode 100644 index 000000000..652fcdaac --- /dev/null +++ b/weed/util/http/client/http_client_name_string.go @@ -0,0 +1,23 @@ +// Code generated by "stringer -type=ClientName -output=http_client_name_string.go"; DO NOT EDIT. + +package client + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Client-0] +} + +const _ClientName_name = "Client" + +var _ClientName_index = [...]uint8{0, 6} + +func (i ClientName) String() string { + if i < 0 || i >= ClientName(len(_ClientName_index)-1) { + return "ClientName(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ClientName_name[_ClientName_index[i]:_ClientName_index[i+1]] +} diff --git a/weed/util/http/client/http_client_opt.go b/weed/util/http/client/http_client_opt.go new file mode 100644 index 000000000..1ff9d533d --- /dev/null +++ b/weed/util/http/client/http_client_opt.go @@ -0,0 +1,18 @@ +package client + +import ( + "net" + "time" +) + +type HttpClientOpt = func(clientCfg *HTTPClient) + +func AddDialContext(httpClient *HTTPClient) { + dialContext := (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext + + httpClient.Transport.DialContext = dialContext + httpClient.Client.Transport = httpClient.Transport +} |
