aboutsummaryrefslogtreecommitdiff
path: root/weed/util/http/client/http_client.go
diff options
context:
space:
mode:
Diffstat (limited to 'weed/util/http/client/http_client.go')
-rw-r--r--weed/util/http/client/http_client.go201
1 files changed, 201 insertions, 0 deletions
diff --git a/weed/util/http/client/http_client.go b/weed/util/http/client/http_client.go
new file mode 100644
index 000000000..d1d2f5c56
--- /dev/null
+++ b/weed/util/http/client/http_client.go
@@ -0,0 +1,201 @@
+package client
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ util "github.com/seaweedfs/seaweedfs/weed/util"
+ "github.com/spf13/viper"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+)
+
+var (
+ loadSecurityConfigOnce sync.Once
+)
+
+type HTTPClient struct {
+ Client *http.Client
+ Transport *http.Transport
+ expectHttpsScheme bool
+}
+
+func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) {
+ req.URL.Scheme = httpClient.GetHttpScheme()
+ return httpClient.Client.Do(req)
+}
+
+func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Get(url)
+}
+
+func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Post(url, contentType, body)
+}
+
+func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.PostForm(url, data)
+}
+
+func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) {
+ url, err = httpClient.NormalizeHttpScheme(url)
+ if err != nil {
+ return nil, err
+ }
+ return httpClient.Client.Head(url)
+}
+func (httpClient *HTTPClient) CloseIdleConnections() {
+ httpClient.Client.CloseIdleConnections()
+}
+
+func (httpClient *HTTPClient) GetClientTransport() *http.Transport {
+ return httpClient.Transport
+}
+
+func (httpClient *HTTPClient) GetHttpScheme() string {
+ if httpClient.expectHttpsScheme {
+ return "https"
+ }
+ return "http"
+}
+
+func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) {
+ expectedScheme := httpClient.GetHttpScheme()
+
+ if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) {
+ return expectedScheme + "://" + rawURL, nil
+ }
+
+ parsedURL, err := url.Parse(rawURL)
+ if err != nil {
+ return "", err
+ }
+
+ if expectedScheme != parsedURL.Scheme {
+ parsedURL.Scheme = expectedScheme
+ }
+ return parsedURL.String(), nil
+}
+
+func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) {
+ httpClient := HTTPClient{}
+ httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName)
+ var tlsConfig *tls.Config = nil
+
+ if httpClient.expectHttpsScheme {
+ clientCertPair, err := getClientCertPair(clientName)
+ if err != nil {
+ return nil, err
+ }
+
+ clientCaCert, clientCaCertName, err := getClientCaCert(clientName)
+ if err != nil {
+ return nil, err
+ }
+
+ if clientCertPair != nil || len(clientCaCert) != 0 {
+ caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName)
+ if err != nil {
+ return nil, err
+ }
+
+ tlsConfig = &tls.Config{
+ Certificates: []tls.Certificate{},
+ RootCAs: caCertPool,
+ InsecureSkipVerify: false,
+ }
+
+ if clientCertPair != nil {
+ tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair)
+ }
+ }
+ }
+
+ httpClient.Transport = &http.Transport{
+ MaxIdleConns: 1024,
+ MaxIdleConnsPerHost: 1024,
+ TLSClientConfig: tlsConfig,
+ }
+ httpClient.Client = &http.Client{
+ Transport: httpClient.Transport,
+ }
+
+ for _, opt := range opts {
+ opt(&httpClient)
+ }
+ return &httpClient, nil
+}
+
+func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string {
+ util.LoadSecurityConfiguration()
+ return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName))
+}
+
+func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool {
+ util.LoadSecurityConfiguration()
+ return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName))
+}
+
+func checkIsHttpsClientEnabled(clientName ClientName) bool {
+ return getBoolOptionFromSecurityConfiguration(clientName, "enabled")
+}
+
+func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) {
+ if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" {
+ fileContent, err := os.ReadFile(fileName)
+ if err != nil {
+ return nil, fileName, err
+ }
+ return fileContent, fileName, err
+ }
+ return nil, "", nil
+}
+
+func getClientCertPair(clientName ClientName) (*tls.Certificate, error) {
+ certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert")
+ keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key")
+ if certFileName == "" && keyFileName == "" {
+ return nil, nil
+ }
+ if certFileName != "" && keyFileName != "" {
+ clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName)
+ if err != nil {
+ return nil, fmt.Errorf("error loading client certificate and key: %s", err)
+ }
+ return &clientCert, nil
+ }
+ return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName)
+}
+
+func getClientCaCert(clientName ClientName) ([]byte, string, error) {
+ return getFileContentFromSecurityConfiguration(clientName, "ca")
+}
+
+func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
+ certPool := x509.NewCertPool()
+ if len(certContent) == 0 {
+ return certPool, nil
+ }
+
+ ok := certPool.AppendCertsFromPEM(certContent)
+ if !ok {
+ return nil, fmt.Errorf("error processing certificate in %s", fileName)
+ }
+ return certPool, nil
+}