diff options
Diffstat (limited to 'weed/util')
| -rw-r--r-- | weed/util/http/client/http_client.go | 201 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_interface.go | 16 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_name.go | 14 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_name_string.go | 23 | ||||
| -rw-r--r-- | weed/util/http/client/http_client_opt.go | 18 | ||||
| -rw-r--r-- | weed/util/http/http_global_client_init.go | 27 | ||||
| -rw-r--r-- | weed/util/http/http_global_client_util.go (renamed from weed/util/http_util.go) | 57 |
7 files changed, 319 insertions, 37 deletions
diff --git a/weed/util/http/client/http_client.go b/weed/util/http/client/http_client.go new file mode 100644 index 000000000..d1d2f5c56 --- /dev/null +++ b/weed/util/http/client/http_client.go @@ -0,0 +1,201 @@ +package client + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + util "github.com/seaweedfs/seaweedfs/weed/util" + "github.com/spf13/viper" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" +) + +var ( + loadSecurityConfigOnce sync.Once +) + +type HTTPClient struct { + Client *http.Client + Transport *http.Transport + expectHttpsScheme bool +} + +func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) { + req.URL.Scheme = httpClient.GetHttpScheme() + return httpClient.Client.Do(req) +} + +func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Get(url) +} + +func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Post(url, contentType, body) +} + +func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.PostForm(url, data) +} + +func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) { + url, err = httpClient.NormalizeHttpScheme(url) + if err != nil { + return nil, err + } + return httpClient.Client.Head(url) +} +func (httpClient *HTTPClient) CloseIdleConnections() { + httpClient.Client.CloseIdleConnections() +} + +func (httpClient *HTTPClient) GetClientTransport() *http.Transport { + return httpClient.Transport +} + +func (httpClient *HTTPClient) GetHttpScheme() string { + if httpClient.expectHttpsScheme { + return "https" + } + return "http" +} + +func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) { + expectedScheme := httpClient.GetHttpScheme() + + if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) { + return expectedScheme + "://" + rawURL, nil + } + + parsedURL, err := url.Parse(rawURL) + if err != nil { + return "", err + } + + if expectedScheme != parsedURL.Scheme { + parsedURL.Scheme = expectedScheme + } + return parsedURL.String(), nil +} + +func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) { + httpClient := HTTPClient{} + httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName) + var tlsConfig *tls.Config = nil + + if httpClient.expectHttpsScheme { + clientCertPair, err := getClientCertPair(clientName) + if err != nil { + return nil, err + } + + clientCaCert, clientCaCertName, err := getClientCaCert(clientName) + if err != nil { + return nil, err + } + + if clientCertPair != nil || len(clientCaCert) != 0 { + caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName) + if err != nil { + return nil, err + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{}, + RootCAs: caCertPool, + InsecureSkipVerify: false, + } + + if clientCertPair != nil { + tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair) + } + } + } + + httpClient.Transport = &http.Transport{ + MaxIdleConns: 1024, + MaxIdleConnsPerHost: 1024, + TLSClientConfig: tlsConfig, + } + httpClient.Client = &http.Client{ + Transport: httpClient.Transport, + } + + for _, opt := range opts { + opt(&httpClient) + } + return &httpClient, nil +} + +func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string { + util.LoadSecurityConfiguration() + return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName)) +} + +func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool { + util.LoadSecurityConfiguration() + return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName)) +} + +func checkIsHttpsClientEnabled(clientName ClientName) bool { + return getBoolOptionFromSecurityConfiguration(clientName, "enabled") +} + +func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) { + if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" { + fileContent, err := os.ReadFile(fileName) + if err != nil { + return nil, fileName, err + } + return fileContent, fileName, err + } + return nil, "", nil +} + +func getClientCertPair(clientName ClientName) (*tls.Certificate, error) { + certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert") + keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key") + if certFileName == "" && keyFileName == "" { + return nil, nil + } + if certFileName != "" && keyFileName != "" { + clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName) + if err != nil { + return nil, fmt.Errorf("error loading client certificate and key: %s", err) + } + return &clientCert, nil + } + return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName) +} + +func getClientCaCert(clientName ClientName) ([]byte, string, error) { + return getFileContentFromSecurityConfiguration(clientName, "ca") +} + +func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + if len(certContent) == 0 { + return certPool, nil + } + + ok := certPool.AppendCertsFromPEM(certContent) + if !ok { + return nil, fmt.Errorf("error processing certificate in %s", fileName) + } + return certPool, nil +} diff --git a/weed/util/http/client/http_client_interface.go b/weed/util/http/client/http_client_interface.go new file mode 100644 index 000000000..7a2d43360 --- /dev/null +++ b/weed/util/http/client/http_client_interface.go @@ -0,0 +1,16 @@ +package client + +import ( + "io" + "net/http" + "net/url" +) + +type HTTPClientInterface interface { + Do(req *http.Request) (*http.Response, error) + Get(url string) (resp *http.Response, err error) + Post(url, contentType string, body io.Reader) (resp *http.Response, err error) + PostForm(url string, data url.Values) (resp *http.Response, err error) + Head(url string) (resp *http.Response, err error) + CloseIdleConnections() +} diff --git a/weed/util/http/client/http_client_name.go b/weed/util/http/client/http_client_name.go new file mode 100644 index 000000000..aedaebbc6 --- /dev/null +++ b/weed/util/http/client/http_client_name.go @@ -0,0 +1,14 @@ +package client + +import "strings" + +type ClientName int + +//go:generate stringer -type=ClientName -output=http_client_name_string.go +const ( + Client ClientName = iota +) + +func (name *ClientName) LowerCaseString() string { + return strings.ToLower(name.String()) +} diff --git a/weed/util/http/client/http_client_name_string.go b/weed/util/http/client/http_client_name_string.go new file mode 100644 index 000000000..652fcdaac --- /dev/null +++ b/weed/util/http/client/http_client_name_string.go @@ -0,0 +1,23 @@ +// Code generated by "stringer -type=ClientName -output=http_client_name_string.go"; DO NOT EDIT. + +package client + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Client-0] +} + +const _ClientName_name = "Client" + +var _ClientName_index = [...]uint8{0, 6} + +func (i ClientName) String() string { + if i < 0 || i >= ClientName(len(_ClientName_index)-1) { + return "ClientName(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ClientName_name[_ClientName_index[i]:_ClientName_index[i+1]] +} diff --git a/weed/util/http/client/http_client_opt.go b/weed/util/http/client/http_client_opt.go new file mode 100644 index 000000000..1ff9d533d --- /dev/null +++ b/weed/util/http/client/http_client_opt.go @@ -0,0 +1,18 @@ +package client + +import ( + "net" + "time" +) + +type HttpClientOpt = func(clientCfg *HTTPClient) + +func AddDialContext(httpClient *HTTPClient) { + dialContext := (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 10 * time.Second, + }).DialContext + + httpClient.Transport.DialContext = dialContext + httpClient.Client.Transport = httpClient.Transport +} diff --git a/weed/util/http/http_global_client_init.go b/weed/util/http/http_global_client_init.go new file mode 100644 index 000000000..0dcb05cfd --- /dev/null +++ b/weed/util/http/http_global_client_init.go @@ -0,0 +1,27 @@ +package http + +import ( + "github.com/seaweedfs/seaweedfs/weed/glog" + util_http_client "github.com/seaweedfs/seaweedfs/weed/util/http/client" +) + +var ( + globalHttpClient *util_http_client.HTTPClient +) + +func NewGlobalHttpClient(opt ...util_http_client.HttpClientOpt) (*util_http_client.HTTPClient, error) { + return util_http_client.NewHttpClient(util_http_client.Client, opt...) +} + +func GetGlobalHttpClient() *util_http_client.HTTPClient { + return globalHttpClient +} + +func InitGlobalHttpClient() { + var err error + + globalHttpClient, err = NewGlobalHttpClient() + if err != nil { + glog.Fatalf("error init global http client: %v", err) + } +} diff --git a/weed/util/http_util.go b/weed/util/http/http_global_client_util.go index 837b3ccb6..c3931a790 100644 --- a/weed/util/http_util.go +++ b/weed/util/http/http_global_client_util.go @@ -1,4 +1,4 @@ -package util +package http import ( "compress/gzip" @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/seaweedfs/seaweedfs/weed/util/mem" + "github.com/seaweedfs/seaweedfs/weed/util" "io" "net/http" "net/url" @@ -15,23 +16,8 @@ import ( "github.com/seaweedfs/seaweedfs/weed/glog" ) -var ( - client *http.Client - Transport *http.Transport -) - -func init() { - Transport = &http.Transport{ - MaxIdleConns: 1024, - MaxIdleConnsPerHost: 1024, - } - client = &http.Client{ - Transport: Transport, - } -} - func Post(url string, values url.Values) ([]byte, error) { - r, err := client.PostForm(url, values) + r, err := GetGlobalHttpClient().PostForm(url, values) if err != nil { return nil, err } @@ -64,7 +50,7 @@ func GetAuthenticated(url, jwt string) ([]byte, bool, error) { maybeAddAuth(request, jwt) request.Header.Add("Accept-Encoding", "gzip") - response, err := client.Do(request) + response, err := GetGlobalHttpClient().Do(request) if err != nil { return nil, true, err } @@ -94,7 +80,7 @@ func GetAuthenticated(url, jwt string) ([]byte, bool, error) { } func Head(url string) (http.Header, error) { - r, err := client.Head(url) + r, err := GetGlobalHttpClient().Head(url) if err != nil { return nil, err } @@ -117,7 +103,7 @@ func Delete(url string, jwt string) error { if err != nil { return err } - resp, e := client.Do(req) + resp, e := GetGlobalHttpClient().Do(req) if e != nil { return e } @@ -145,7 +131,7 @@ func DeleteProxied(url string, jwt string) (body []byte, httpStatus int, err err if err != nil { return } - resp, err := client.Do(req) + resp, err := GetGlobalHttpClient().Do(req) if err != nil { return } @@ -159,7 +145,7 @@ func DeleteProxied(url string, jwt string) (body []byte, httpStatus int, err err } func GetBufferStream(url string, values url.Values, allocatedBytes []byte, eachBuffer func([]byte)) error { - r, err := client.PostForm(url, values) + r, err := GetGlobalHttpClient().PostForm(url, values) if err != nil { return err } @@ -182,7 +168,7 @@ func GetBufferStream(url string, values url.Values, allocatedBytes []byte, eachB } func GetUrlStream(url string, values url.Values, readFn func(io.Reader) error) error { - r, err := client.PostForm(url, values) + r, err := GetGlobalHttpClient().PostForm(url, values) if err != nil { return err } @@ -201,7 +187,7 @@ func DownloadFile(fileUrl string, jwt string) (filename string, header http.Head maybeAddAuth(req, jwt) - response, err := client.Do(req) + response, err := GetGlobalHttpClient().Do(req) if err != nil { return "", nil, nil, err } @@ -219,14 +205,11 @@ func DownloadFile(fileUrl string, jwt string) (filename string, header http.Head } func Do(req *http.Request) (resp *http.Response, err error) { - return client.Do(req) + return GetGlobalHttpClient().Do(req) } -func NormalizeUrl(url string) string { - if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { - return url - } - return "http://" + url +func NormalizeUrl(url string) (string, error) { + return GetGlobalHttpClient().NormalizeHttpScheme(url) } func ReadUrl(fileUrl string, cipherKey []byte, isContentCompressed bool, isFullChunk bool, offset int64, size int, buf []byte) (int64, error) { @@ -249,7 +232,7 @@ func ReadUrl(fileUrl string, cipherKey []byte, isContentCompressed bool, isFullC req.Header.Set("Accept-Encoding", "gzip") } - r, err := client.Do(req) + r, err := GetGlobalHttpClient().Do(req) if err != nil { return 0, err } @@ -322,7 +305,7 @@ func ReadUrlAsStreamAuthenticated(fileUrl, jwt string, cipherKey []byte, isConte req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+int64(size)-1)) } - r, err := client.Do(req) + r, err := GetGlobalHttpClient().Do(req) if err != nil { return true, err } @@ -368,12 +351,12 @@ func readEncryptedUrl(fileUrl, jwt string, cipherKey []byte, isContentCompressed if err != nil { return retryable, fmt.Errorf("fetch %s: %v", fileUrl, err) } - decryptedData, err := Decrypt(encryptedData, CipherKey(cipherKey)) + decryptedData, err := util.Decrypt(encryptedData, util.CipherKey(cipherKey)) if err != nil { return false, fmt.Errorf("decrypt %s: %v", fileUrl, err) } if isContentCompressed { - decryptedData, err = DecompressData(decryptedData) + decryptedData, err = util.DecompressData(decryptedData) if err != nil { glog.V(0).Infof("unzip decrypt %s: %v", fileUrl, err) } @@ -403,7 +386,7 @@ func ReadUrlAsReaderCloser(fileUrl string, jwt string, rangeHeader string) (*htt maybeAddAuth(req, jwt) - r, err := client.Do(req) + r, err := GetGlobalHttpClient().Do(req) if err != nil { return nil, nil, err } @@ -463,7 +446,7 @@ func RetriedFetchChunkData(buffer []byte, urlStrings []string, cipherKey []byte, var shouldRetry bool - for waitTime := time.Second; waitTime < RetryWaitTime; waitTime += waitTime / 2 { + for waitTime := time.Second; waitTime < util.RetryWaitTime; waitTime += waitTime / 2 { for _, urlString := range urlStrings { n = 0 if strings.Contains(urlString, "%") { @@ -494,4 +477,4 @@ func RetriedFetchChunkData(buffer []byte, urlStrings []string, cipherKey []byte, return n, err -} +}
\ No newline at end of file |
