aboutsummaryrefslogtreecommitdiff
path: root/weed/util/http/client
diff options
context:
space:
mode:
authorvadimartynov <166398828+vadimartynov@users.noreply.github.com>2024-07-17 09:14:09 +0300
committerGitHub <noreply@github.com>2024-07-16 23:14:09 -0700
commit86d92a42b4861d4bb05c58fea9db84d960995545 (patch)
treeb3b8cefc07fe3d10c0dc0c69120a9019584bd60a /weed/util/http/client
parentc6dec11ea556b8be648f372dfa5cbd074c9f631b (diff)
downloadseaweedfs-86d92a42b4861d4bb05c58fea9db84d960995545.tar.xz
seaweedfs-86d92a42b4861d4bb05c58fea9db84d960995545.zip
Added tls for http clients (#5766)
* Added global http client * Added Do func for global http client * Changed the code to use the global http client * Fix http client in volume uploader * Fixed pkg name * Fixed http util funcs * Fixed http client for bench_filer_upload * Fixed http client for stress_filer_upload * Fixed http client for filer_server_handlers_proxy * Fixed http client for command_fs_merge_volumes * Fixed http client for command_fs_merge_volumes and command_volume_fsck * Fixed http client for s3api_server * Added init global client for main funcs * Rename global_client to client * Changed: - fixed NewHttpClient; - added CheckIsHttpsClientEnabled func - updated security.toml in scaffold * Reduce the visibility of some functions in the util/http/client pkg * Added the loadSecurityConfig function * Use util.LoadSecurityConfiguration() in NewHttpClient func
Diffstat (limited to 'weed/util/http/client')
-rw-r--r--weed/util/http/client/http_client.go201
-rw-r--r--weed/util/http/client/http_client_interface.go16
-rw-r--r--weed/util/http/client/http_client_name.go14
-rw-r--r--weed/util/http/client/http_client_name_string.go23
-rw-r--r--weed/util/http/client/http_client_opt.go18
5 files changed, 272 insertions, 0 deletions
diff --git a/weed/util/http/client/http_client.go b/weed/util/http/client/http_client.go
new file mode 100644
index 000000000..d1d2f5c56
--- /dev/null
+++ b/weed/util/http/client/http_client.go
@@ -0,0 +1,201 @@
+package client
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ util "github.com/seaweedfs/seaweedfs/weed/util"
+ "github.com/spf13/viper"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+)
+
+var (
+ loadSecurityConfigOnce sync.Once
+)
+
+type HTTPClient struct {
+ Client *http.Client
+ Transport *http.Transport
+ expectHttpsScheme bool
+}
+
+func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) {
+ req.URL.Scheme = httpClient.GetHttpScheme()
+ return httpClient.Client.Do(req)
+}
+
+func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Get(url)
+}
+
+func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Post(url, contentType, body)
+}
+
+func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.PostForm(url, data)
+}
+
+func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Head(url)
+}
+func (httpClient *HTTPClient) CloseIdleConnections() {
+ httpClient.Client.CloseIdleConnections()
+}
+
+func (httpClient *HTTPClient) GetClientTransport() *http.Transport {
+ return httpClient.Transport
+}
+
+func (httpClient *HTTPClient) GetHttpScheme() string {
+ if httpClient.expectHttpsScheme {
+ return "https"
+ }
+ return "http"
+}
+
+func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) {
+ expectedScheme := httpClient.GetHttpScheme()
+
+ if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) {
+ return expectedScheme + "://" + rawURL, nil
+ }
+
+ parsedURL, err := url.Parse(rawURL)
+ if err != nil {
+ return "", err
+ }
+
+ if expectedScheme != parsedURL.Scheme {
+ parsedURL.Scheme = expectedScheme
+ }
+ return parsedURL.String(), nil
+}
+
+func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) {
+ httpClient := HTTPClient{}
+ httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName)
+ var tlsConfig *tls.Config = nil
+
+ if httpClient.expectHttpsScheme {
+ clientCertPair, err := getClientCertPair(clientName)
+ if err != nil {
+ return nil, err
+ }
+
+ clientCaCert, clientCaCertName, err := getClientCaCert(clientName)
+ if err != nil {
+ return nil, err
+ }
+
+ if clientCertPair != nil || len(clientCaCert) != 0 {
+ caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName)
+ if err != nil {
+ return nil, err
+ }
+
+ tlsConfig = &tls.Config{
+ Certificates: []tls.Certificate{},
+ RootCAs: caCertPool,
+ InsecureSkipVerify: false,
+ }
+
+ if clientCertPair != nil {
+ tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair)
+ }
+ }
+ }
+
+ httpClient.Transport = &http.Transport{
+ MaxIdleConns: 1024,
+ MaxIdleConnsPerHost: 1024,
+ TLSClientConfig: tlsConfig,
+ }
+ httpClient.Client = &http.Client{
+ Transport: httpClient.Transport,
+ }
+
+ for _, opt := range opts {
+ opt(&httpClient)
+ }
+ return &httpClient, nil
+}
+
+func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string {
+ util.LoadSecurityConfiguration()
+ return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName))
+}
+
+func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool {
+ util.LoadSecurityConfiguration()
+ return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName))
+}
+
+func checkIsHttpsClientEnabled(clientName ClientName) bool {
+ return getBoolOptionFromSecurityConfiguration(clientName, "enabled")
+}
+
+func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) {
+ if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" {
+ fileContent, err := os.ReadFile(fileName)
+ if err != nil {
+ return nil, fileName, err
+ }
+ return fileContent, fileName, err
+ }
+ return nil, "", nil
+}
+
+func getClientCertPair(clientName ClientName) (*tls.Certificate, error) {
+ certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert")
+ keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key")
+ if certFileName == "" && keyFileName == "" {
+ return nil, nil
+ }
+ if certFileName != "" && keyFileName != "" {
+ clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName)
+ if err != nil {
+ return nil, fmt.Errorf("error loading client certificate and key: %s", err)
+ }
+ return &clientCert, nil
+ }
+ return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName)
+}
+
+func getClientCaCert(clientName ClientName) ([]byte, string, error) {
+ return getFileContentFromSecurityConfiguration(clientName, "ca")
+}
+
+func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
+ certPool := x509.NewCertPool()
+ if len(certContent) == 0 {
+ return certPool, nil
+ }
+
+ ok := certPool.AppendCertsFromPEM(certContent)
+ if !ok {
+ return nil, fmt.Errorf("error processing certificate in %s", fileName)
+ }
+ return certPool, nil
+}
diff --git a/weed/util/http/client/http_client_interface.go b/weed/util/http/client/http_client_interface.go
new file mode 100644
index 000000000..7a2d43360
--- /dev/null
+++ b/weed/util/http/client/http_client_interface.go
@@ -0,0 +1,16 @@
+package client
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+)
+
+type HTTPClientInterface interface {
+ Do(req *http.Request) (*http.Response, error)
+ Get(url string) (resp *http.Response, err error)
+ Post(url, contentType string, body io.Reader) (resp *http.Response, err error)
+ PostForm(url string, data url.Values) (resp *http.Response, err error)
+ Head(url string) (resp *http.Response, err error)
+ CloseIdleConnections()
+}
diff --git a/weed/util/http/client/http_client_name.go b/weed/util/http/client/http_client_name.go
new file mode 100644
index 000000000..aedaebbc6
--- /dev/null
+++ b/weed/util/http/client/http_client_name.go
@@ -0,0 +1,14 @@
+package client
+
+import "strings"
+
+type ClientName int
+
+//go:generate stringer -type=ClientName -output=http_client_name_string.go
+const (
+ Client ClientName = iota
+)
+
+func (name *ClientName) LowerCaseString() string {
+ return strings.ToLower(name.String())
+}
diff --git a/weed/util/http/client/http_client_name_string.go b/weed/util/http/client/http_client_name_string.go
new file mode 100644
index 000000000..652fcdaac
--- /dev/null
+++ b/weed/util/http/client/http_client_name_string.go
@@ -0,0 +1,23 @@
+// Code generated by "stringer -type=ClientName -output=http_client_name_string.go"; DO NOT EDIT.
+
+package client
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Client-0]
+}
+
+const _ClientName_name = "Client"
+
+var _ClientName_index = [...]uint8{0, 6}
+
+func (i ClientName) String() string {
+ if i < 0 || i >= ClientName(len(_ClientName_index)-1) {
+ return "ClientName(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _ClientName_name[_ClientName_index[i]:_ClientName_index[i+1]]
+}
diff --git a/weed/util/http/client/http_client_opt.go b/weed/util/http/client/http_client_opt.go
new file mode 100644
index 000000000..1ff9d533d
--- /dev/null
+++ b/weed/util/http/client/http_client_opt.go
@@ -0,0 +1,18 @@
+package client
+
+import (
+ "net"
+ "time"
+)
+
+type HttpClientOpt = func(clientCfg *HTTPClient)
+
+func AddDialContext(httpClient *HTTPClient) {
+ dialContext := (&net.Dialer{
+ Timeout: 10 * time.Second,
+ KeepAlive: 10 * time.Second,
+ }).DialContext
+
+ httpClient.Transport.DialContext = dialContext
+ httpClient.Client.Transport = httpClient.Transport
+}