diff options
Diffstat (limited to 'weed/util')
| -rw-r--r-- | weed/util/bytes.go | 31 | ||||
| -rw-r--r-- | weed/util/cipher.go | 81 | ||||
| -rw-r--r-- | weed/util/compression.go | 10 | ||||
| -rw-r--r-- | weed/util/config.go | 21 | ||||
| -rw-r--r-- | weed/util/constants.go | 2 | ||||
| -rw-r--r-- | weed/util/file_util.go | 19 | ||||
| -rw-r--r-- | weed/util/grpc_client_server.go | 120 | ||||
| -rw-r--r-- | weed/util/http_util.go | 146 | ||||
| -rw-r--r-- | weed/util/httpdown/http_down.go | 395 | ||||
| -rw-r--r-- | weed/util/net_timeout.go | 7 | ||||
| -rw-r--r-- | weed/util/queue.go | 61 | ||||
| -rw-r--r-- | weed/util/queue_unbounded.go | 45 | ||||
| -rw-r--r-- | weed/util/queue_unbounded_test.go | 25 |
13 files changed, 796 insertions, 167 deletions
diff --git a/weed/util/bytes.go b/weed/util/bytes.go index dfa4ae665..9c7e5e2cb 100644 --- a/weed/util/bytes.go +++ b/weed/util/bytes.go @@ -1,5 +1,10 @@ package util +import ( + "crypto/md5" + "io" +) + // big endian func BytesToUint64(b []byte) (v uint64) { @@ -43,3 +48,29 @@ func Uint16toBytes(b []byte, v uint16) { func Uint8toBytes(b []byte, v uint8) { b[0] = byte(v) } + +// returns a 64 bit big int +func HashStringToLong(dir string) (v int64) { + h := md5.New() + io.WriteString(h, dir) + + b := h.Sum(nil) + + v += int64(b[0]) + v <<= 8 + v += int64(b[1]) + v <<= 8 + v += int64(b[2]) + v <<= 8 + v += int64(b[3]) + v <<= 8 + v += int64(b[4]) + v <<= 8 + v += int64(b[5]) + v <<= 8 + v += int64(b[6]) + v <<= 8 + v += int64(b[7]) + + return +} diff --git a/weed/util/cipher.go b/weed/util/cipher.go new file mode 100644 index 000000000..7bcb6559a --- /dev/null +++ b/weed/util/cipher.go @@ -0,0 +1,81 @@ +package util + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "fmt" + "io" + "io/ioutil" + + "github.com/chrislusf/seaweedfs/weed/glog" +) + +type CipherKey []byte + +func GenCipherKey() CipherKey { + key := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + glog.Fatalf("random key gen: %v", err) + } + return CipherKey(key) +} + +func Encrypt(plaintext []byte, key CipherKey) ([]byte, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, plaintext, nil), nil +} + +func Decrypt(ciphertext []byte, key CipherKey) ([]byte, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + return gcm.Open(nil, nonce, ciphertext, nil) +} + +func EncryptReader(clearReader io.Reader) (cipherKey CipherKey, encryptedReader io.ReadCloser, clearDataLen, encryptedDataLen int, err error) { + clearData, err := ioutil.ReadAll(clearReader) + if err != nil { + err = fmt.Errorf("read raw input: %v", err) + return + } + clearDataLen = len(clearData) + cipherKey = GenCipherKey() + encryptedData, err := Encrypt(clearData, cipherKey) + if err != nil { + err = fmt.Errorf("encrypt input: %v", err) + return + } + encryptedDataLen = len(encryptedData) + encryptedReader = ioutil.NopCloser(bytes.NewReader(encryptedData)) + return +} diff --git a/weed/util/compression.go b/weed/util/compression.go index c6c9423e2..6072df632 100644 --- a/weed/util/compression.go +++ b/weed/util/compression.go @@ -60,7 +60,7 @@ func UnGzipData(input []byte) ([]byte, error) { // images switch ext { - case ".svg", ".bmp": + case ".svg", ".bmp", ".wav": return true, true } if strings.HasPrefix(mtype, "image/") { @@ -87,6 +87,14 @@ func UnGzipData(input []byte) ([]byte, error) { if strings.HasSuffix(mtype, "script") { return true, true } + + } + + if strings.HasPrefix(mtype, "audio/") { + switch strings.TrimPrefix(mtype, "audio/") { + case "wave", "wav", "x-wav", "x-pn-wav": + return true, true + } } return false, false diff --git a/weed/util/config.go b/weed/util/config.go index 84f146bc8..dfbfdbd82 100644 --- a/weed/util/config.go +++ b/weed/util/config.go @@ -1,17 +1,19 @@ package util import ( - "github.com/chrislusf/seaweedfs/weed/glog" + "strings" + "github.com/spf13/viper" + + "github.com/chrislusf/seaweedfs/weed/glog" ) type Configuration interface { GetString(key string) string GetBool(key string) bool GetInt(key string) int - GetInt64(key string) int64 - GetFloat64(key string) float64 GetStringSlice(key string) []string + SetDefault(key string, value interface{}) } func LoadConfiguration(configFileName string, required bool) (loaded bool) { @@ -28,10 +30,7 @@ func LoadConfiguration(configFileName string, required bool) (loaded bool) { glog.V(0).Infof("Reading %s: %v", viper.ConfigFileUsed(), err) if required { glog.Fatalf("Failed to load %s.toml file from current directory, or $HOME/.seaweedfs/, or /etc/seaweedfs/"+ - "\n\nPlease follow this example and add a filer.toml file to "+ - "current directory, or $HOME/.seaweedfs/, or /etc/seaweedfs/:\n"+ - " https://github.com/chrislusf/seaweedfs/blob/master/weed/%s.toml\n"+ - "\nOr use this command to generate the default toml file\n"+ + "\n\nPlease use this command to generate the default %s.toml file\n"+ " weed scaffold -config=%s -output=.\n\n\n", configFileName, configFileName, configFileName) } else { @@ -41,3 +40,11 @@ func LoadConfiguration(configFileName string, required bool) (loaded bool) { return true } + +func GetViper() *viper.Viper { + v := viper.GetViper() + v.AutomaticEnv() + v.SetEnvPrefix("weed") + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + return v +} diff --git a/weed/util/constants.go b/weed/util/constants.go index f0df5fd59..c23bc11f6 100644 --- a/weed/util/constants.go +++ b/weed/util/constants.go @@ -5,5 +5,5 @@ import ( ) var ( - VERSION = fmt.Sprintf("%s %d.%d", sizeLimit, 1, 45) + VERSION = fmt.Sprintf("%s %d.%d", sizeLimit, 1, 61) ) diff --git a/weed/util/file_util.go b/weed/util/file_util.go index 78add6724..bef9f7cd6 100644 --- a/weed/util/file_util.go +++ b/weed/util/file_util.go @@ -3,6 +3,7 @@ package util import ( "errors" "os" + "time" "github.com/chrislusf/seaweedfs/weed/glog" ) @@ -40,3 +41,21 @@ func FileExists(filename string) bool { return true } + +func CheckFile(filename string) (exists, canRead, canWrite bool, modTime time.Time, fileSize int64) { + exists = true + fi, err := os.Stat(filename) + if os.IsNotExist(err) { + exists = false + return + } + if fi.Mode()&0400 != 0 { + canRead = true + } + if fi.Mode()&0200 != 0 { + canWrite = true + } + modTime = fi.ModTime() + fileSize = fi.Size() + return +} diff --git a/weed/util/grpc_client_server.go b/weed/util/grpc_client_server.go deleted file mode 100644 index 31497ad35..000000000 --- a/weed/util/grpc_client_server.go +++ /dev/null @@ -1,120 +0,0 @@ -package util - -import ( - "context" - "fmt" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" -) - -var ( - // cache grpc connections - grpcClients = make(map[string]*grpc.ClientConn) - grpcClientsLock sync.Mutex -) - -func init() { - http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 1024 -} - -func NewGrpcServer(opts ...grpc.ServerOption) *grpc.Server { - var options []grpc.ServerOption - options = append(options, grpc.KeepaliveParams(keepalive.ServerParameters{ - Time: 10 * time.Second, // wait time before ping if no activity - Timeout: 20 * time.Second, // ping timeout - }), grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ - MinTime: 60 * time.Second, // min time a client should wait before sending a ping - })) - for _, opt := range opts { - if opt != nil { - options = append(options, opt) - } - } - return grpc.NewServer(options...) -} - -func GrpcDial(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - // opts = append(opts, grpc.WithBlock()) - // opts = append(opts, grpc.WithTimeout(time.Duration(5*time.Second))) - var options []grpc.DialOption - options = append(options, - // grpc.WithInsecure(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, // client ping server if no activity for this long - Timeout: 20 * time.Second, - })) - for _, opt := range opts { - if opt != nil { - options = append(options, opt) - } - } - return grpc.DialContext(ctx, address, options...) -} - -func WithCachedGrpcClient(ctx context.Context, fn func(*grpc.ClientConn) error, address string, opts ...grpc.DialOption) error { - - grpcClientsLock.Lock() - - existingConnection, found := grpcClients[address] - if found { - grpcClientsLock.Unlock() - return fn(existingConnection) - } - - grpcConnection, err := GrpcDial(ctx, address, opts...) - if err != nil { - grpcClientsLock.Unlock() - return fmt.Errorf("fail to dial %s: %v", address, err) - } - - grpcClients[address] = grpcConnection - grpcClientsLock.Unlock() - - err = fn(grpcConnection) - if err != nil { - grpcClientsLock.Lock() - delete(grpcClients, address) - grpcClientsLock.Unlock() - grpcConnection.Close() - } - - return err -} - -func ParseServerToGrpcAddress(server string) (serverGrpcAddress string, err error) { - colonIndex := strings.LastIndex(server, ":") - if colonIndex < 0 { - return "", fmt.Errorf("server should have hostname:port format: %v", server) - } - - port, parseErr := strconv.ParseUint(server[colonIndex+1:], 10, 64) - if parseErr != nil { - return "", fmt.Errorf("server port parse error: %v", parseErr) - } - - grpcPort := int(port) + 10000 - - return fmt.Sprintf("%s:%d", server[:colonIndex], grpcPort), nil -} - -func ServerToGrpcAddress(server string) (serverGrpcAddress string) { - hostnameAndPort := strings.Split(server, ":") - if len(hostnameAndPort) != 2 { - return fmt.Sprintf("unexpected server address: %s", server) - } - - port, parseErr := strconv.ParseUint(hostnameAndPort[1], 10, 64) - if parseErr != nil { - return fmt.Sprintf("failed to parse port for %s:%s", hostnameAndPort[0], hostnameAndPort[1]) - } - - grpcPort := int(port) + 10000 - - return fmt.Sprintf("%s:%d", hostnameAndPort[0], grpcPort) -} diff --git a/weed/util/http_util.go b/weed/util/http_util.go index 79a442a56..750516b92 100644 --- a/weed/util/http_util.go +++ b/weed/util/http_util.go @@ -11,6 +11,8 @@ import ( "net/http" "net/url" "strings" + + "github.com/chrislusf/seaweedfs/weed/glog" ) var ( @@ -28,18 +30,18 @@ func init() { } func PostBytes(url string, body []byte) ([]byte, error) { - r, err := client.Post(url, "application/octet-stream", bytes.NewReader(body)) + r, err := client.Post(url, "", bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("Post to %s: %v", url, err) } defer r.Body.Close() - if r.StatusCode >= 400 { - return nil, fmt.Errorf("%s: %s", url, r.Status) - } b, err := ioutil.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("Read response body: %v", err) } + if r.StatusCode >= 400 { + return nil, fmt.Errorf("%s: %s", url, r.Status) + } return b, nil } @@ -86,7 +88,7 @@ func Head(url string) (http.Header, error) { if err != nil { return nil, err } - defer r.Body.Close() + defer CloseResponse(r) if r.StatusCode >= 400 { return nil, fmt.Errorf("%s: %s", url, r.Status) } @@ -128,7 +130,7 @@ func GetBufferStream(url string, values url.Values, allocatedBytes []byte, eachB if err != nil { return err } - defer r.Body.Close() + defer CloseResponse(r) if r.StatusCode != 200 { return fmt.Errorf("%s: %s", url, r.Status) } @@ -151,7 +153,7 @@ func GetUrlStream(url string, values url.Values, readFn func(io.Reader) error) e if err != nil { return err } - defer r.Body.Close() + defer CloseResponse(r) if r.StatusCode != 200 { return fmt.Errorf("%s: %s", url, r.Status) } @@ -187,11 +189,22 @@ func NormalizeUrl(url string) string { return "http://" + url } -func ReadUrl(fileUrl string, offset int64, size int, buf []byte, isReadRange bool) (n int64, e error) { +func ReadUrl(fileUrl string, cipherKey []byte, isGzipped bool, isFullChunk bool, offset int64, size int, buf []byte) (int64, error) { + + if cipherKey != nil { + var n int + err := readEncryptedUrl(fileUrl, cipherKey, isGzipped, offset, size, func(data []byte) { + n = copy(buf, data) + }) + return int64(n), err + } - req, _ := http.NewRequest("GET", fileUrl, nil) - if isReadRange { - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+int64(size))) + req, err := http.NewRequest("GET", fileUrl, nil) + if err != nil { + return 0, err + } + if !isFullChunk { + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+int64(size)-1)) } else { req.Header.Set("Accept-Encoding", "gzip") } @@ -207,7 +220,8 @@ func ReadUrl(fileUrl string, offset int64, size int, buf []byte, isReadRange boo } var reader io.ReadCloser - switch r.Header.Get("Content-Encoding") { + contentEncoding := r.Header.Get("Content-Encoding") + switch contentEncoding { case "gzip": reader, err = gzip.NewReader(r.Body) defer reader.Close() @@ -215,55 +229,121 @@ func ReadUrl(fileUrl string, offset int64, size int, buf []byte, isReadRange boo reader = r.Body } - var i, m int + var ( + i, m int + n int64 + ) + // refers to https://github.com/golang/go/blob/master/src/bytes/buffer.go#L199 + // commit id c170b14c2c1cfb2fd853a37add92a82fd6eb4318 for { m, err = reader.Read(buf[i:]) - if m == 0 { - return - } i += m n += int64(m) if err == io.EOF { return n, nil } - if e != nil { - return n, e + if err != nil { + return n, err + } + if n == int64(len(buf)) { + break } } - + // drains the response body to avoid memory leak + data, _ := ioutil.ReadAll(reader) + if len(data) != 0 { + glog.V(1).Infof("%s reader has remaining %d bytes", contentEncoding, len(data)) + } + return n, err } -func ReadUrlAsStream(fileUrl string, offset int64, size int, fn func(data []byte)) (n int64, e error) { +func ReadUrlAsStream(fileUrl string, cipherKey []byte, isContentGzipped bool, isFullChunk bool, offset int64, size int, fn func(data []byte)) error { + + if cipherKey != nil { + return readEncryptedUrl(fileUrl, cipherKey, isContentGzipped, offset, size, fn) + } + + req, err := http.NewRequest("GET", fileUrl, nil) + if err != nil { + return err + } - req, _ := http.NewRequest("GET", fileUrl, nil) - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+int64(size))) + if !isFullChunk { + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, offset+int64(size)-1)) + } r, err := client.Do(req) if err != nil { - return 0, err + return err } - defer r.Body.Close() + defer CloseResponse(r) if r.StatusCode >= 400 { - return 0, fmt.Errorf("%s: %s", fileUrl, r.Status) + return fmt.Errorf("%s: %s", fileUrl, r.Status) } - var m int + var ( + m int + ) buf := make([]byte, 64*1024) for { m, err = r.Body.Read(buf) - if m == 0 { - return - } fn(buf[:m]) - n += int64(m) if err == io.EOF { - return n, nil + return nil + } + if err != nil { + return err } - if e != nil { - return n, e + } + +} + +func readEncryptedUrl(fileUrl string, cipherKey []byte, isContentGzipped bool, offset int64, size int, fn func(data []byte)) error { + encryptedData, err := Get(fileUrl) + if err != nil { + return fmt.Errorf("fetch %s: %v", fileUrl, err) + } + decryptedData, err := Decrypt(encryptedData, CipherKey(cipherKey)) + if err != nil { + return fmt.Errorf("decrypt %s: %v", fileUrl, err) + } + if isContentGzipped { + decryptedData, err = UnGzipData(decryptedData) + if err != nil { + return fmt.Errorf("unzip decrypt %s: %v", fileUrl, err) } } + if len(decryptedData) < int(offset)+size { + return fmt.Errorf("read decrypted %s size %d [%d, %d)", fileUrl, len(decryptedData), offset, int(offset)+size) + } + fn(decryptedData[int(offset) : int(offset)+size]) + return nil +} + +func ReadUrlAsReaderCloser(fileUrl string, rangeHeader string) (io.ReadCloser, error) { + + req, err := http.NewRequest("GET", fileUrl, nil) + if err != nil { + return nil, err + } + if rangeHeader != "" { + req.Header.Add("Range", rangeHeader) + } + + r, err := client.Do(req) + if err != nil { + return nil, err + } + if r.StatusCode >= 400 { + return nil, fmt.Errorf("%s: %s", fileUrl, r.Status) + } + + return r.Body, nil +} +func CloseResponse(resp *http.Response) { + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() } diff --git a/weed/util/httpdown/http_down.go b/weed/util/httpdown/http_down.go new file mode 100644 index 000000000..5cbd9611c --- /dev/null +++ b/weed/util/httpdown/http_down.go @@ -0,0 +1,395 @@ +// Package httpdown provides http.ConnState enabled graceful termination of +// http.Server. +// based on github.com/facebookarchive/httpdown, who's licence is MIT-licence, +// we add a feature of supporting for http TLS +package httpdown + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/facebookgo/clock" + "github.com/facebookgo/stats" +) + +const ( + defaultStopTimeout = time.Minute + defaultKillTimeout = time.Minute +) + +// A Server allows encapsulates the process of accepting new connections and +// serving them, and gracefully shutting down the listener without dropping +// active connections. +type Server interface { + // Wait waits for the serving loop to finish. This will happen when Stop is + // called, at which point it returns no error, or if there is an error in the + // serving loop. You must call Wait after calling Serve or ListenAndServe. + Wait() error + + // Stop stops the listener. It will block until all connections have been + // closed. + Stop() error +} + +// HTTP defines the configuration for serving a http.Server. Multiple calls to +// Serve or ListenAndServe can be made on the same HTTP instance. The default +// timeouts of 1 minute each result in a maximum of 2 minutes before a Stop() +// returns. +type HTTP struct { + // StopTimeout is the duration before we begin force closing connections. + // Defaults to 1 minute. + StopTimeout time.Duration + + // KillTimeout is the duration before which we completely give up and abort + // even though we still have connected clients. This is useful when a large + // number of client connections exist and closing them can take a long time. + // Note, this is in addition to the StopTimeout. Defaults to 1 minute. + KillTimeout time.Duration + + // Stats is optional. If provided, it will be used to record various metrics. + Stats stats.Client + + // Clock allows for testing timing related functionality. Do not specify this + // in production code. + Clock clock.Clock + + // when set CertFile and KeyFile, the httpDown will start a http with TLS. + // Files containing a certificate and matching private key for the + // server must be provided if neither the Server's + // TLSConfig.Certificates nor TLSConfig.GetCertificate are populated. + // If the certificate is signed by a certificate authority, the + // certFile should be the concatenation of the server's certificate, + // any intermediates, and the CA's certificate. + CertFile, KeyFile string +} + +// Serve provides the low-level API which is useful if you're creating your own +// net.Listener. +func (h HTTP) Serve(s *http.Server, l net.Listener) Server { + stopTimeout := h.StopTimeout + if stopTimeout == 0 { + stopTimeout = defaultStopTimeout + } + killTimeout := h.KillTimeout + if killTimeout == 0 { + killTimeout = defaultKillTimeout + } + klock := h.Clock + if klock == nil { + klock = clock.New() + } + + ss := &server{ + stopTimeout: stopTimeout, + killTimeout: killTimeout, + stats: h.Stats, + clock: klock, + oldConnState: s.ConnState, + listener: l, + server: s, + serveDone: make(chan struct{}), + serveErr: make(chan error, 1), + new: make(chan net.Conn), + active: make(chan net.Conn), + idle: make(chan net.Conn), + closed: make(chan net.Conn), + stop: make(chan chan struct{}), + kill: make(chan chan struct{}), + certFile: h.CertFile, + keyFile: h.KeyFile, + } + s.ConnState = ss.connState + go ss.manage() + go ss.serve() + return ss +} + +// ListenAndServe returns a Server for the given http.Server. It is equivalent +// to ListenAndServe from the standard library, but returns immediately. +// Requests will be accepted in a background goroutine. If the http.Server has +// a non-nil TLSConfig, a TLS enabled listener will be setup. +func (h HTTP) ListenAndServe(s *http.Server) (Server, error) { + addr := s.Addr + if addr == "" { + if s.TLSConfig == nil { + addr = ":http" + } else { + addr = ":https" + } + } + l, err := net.Listen("tcp", addr) + if err != nil { + stats.BumpSum(h.Stats, "listen.error", 1) + return nil, err + } + if s.TLSConfig != nil { + l = tls.NewListener(l, s.TLSConfig) + } + return h.Serve(s, l), nil +} + +// server manages the serving process and allows for gracefully stopping it. +type server struct { + stopTimeout time.Duration + killTimeout time.Duration + stats stats.Client + clock clock.Clock + + oldConnState func(net.Conn, http.ConnState) + server *http.Server + serveDone chan struct{} + serveErr chan error + listener net.Listener + + new chan net.Conn + active chan net.Conn + idle chan net.Conn + closed chan net.Conn + stop chan chan struct{} + kill chan chan struct{} + + stopOnce sync.Once + stopErr error + + certFile, keyFile string +} + +func (s *server) connState(c net.Conn, cs http.ConnState) { + if s.oldConnState != nil { + s.oldConnState(c, cs) + } + + switch cs { + case http.StateNew: + s.new <- c + case http.StateActive: + s.active <- c + case http.StateIdle: + s.idle <- c + case http.StateHijacked, http.StateClosed: + s.closed <- c + } +} + +func (s *server) manage() { + defer func() { + close(s.new) + close(s.active) + close(s.idle) + close(s.closed) + close(s.stop) + close(s.kill) + }() + + var stopDone chan struct{} + + conns := map[net.Conn]http.ConnState{} + var countNew, countActive, countIdle float64 + + // decConn decrements the count associated with the current state of the + // given connection. + decConn := func(c net.Conn) { + switch conns[c] { + default: + panic(fmt.Errorf("unknown existing connection: %s", c)) + case http.StateNew: + countNew-- + case http.StateActive: + countActive-- + case http.StateIdle: + countIdle-- + } + } + + // setup a ticker to report various values every minute. if we don't have a + // Stats implementation provided, we Stop it so it never ticks. + statsTicker := s.clock.Ticker(time.Minute) + if s.stats == nil { + statsTicker.Stop() + } + + for { + select { + case <-statsTicker.C: + // we'll only get here when s.stats is not nil + s.stats.BumpAvg("http-state.new", countNew) + s.stats.BumpAvg("http-state.active", countActive) + s.stats.BumpAvg("http-state.idle", countIdle) + s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle) + case c := <-s.new: + conns[c] = http.StateNew + countNew++ + case c := <-s.active: + decConn(c) + countActive++ + + conns[c] = http.StateActive + case c := <-s.idle: + decConn(c) + countIdle++ + + conns[c] = http.StateIdle + + // if we're already stopping, close it + if stopDone != nil { + c.Close() + } + case c := <-s.closed: + stats.BumpSum(s.stats, "conn.closed", 1) + decConn(c) + delete(conns, c) + + // if we're waiting to stop and are all empty, we just closed the last + // connection and we're done. + if stopDone != nil && len(conns) == 0 { + close(stopDone) + return + } + case stopDone = <-s.stop: + // if we're already all empty, we're already done + if len(conns) == 0 { + close(stopDone) + return + } + + // close current idle connections right away + for c, cs := range conns { + if cs == http.StateIdle { + c.Close() + } + } + + // continue the loop and wait for all the ConnState updates which will + // eventually close(stopDone) and return from this goroutine. + + case killDone := <-s.kill: + // force close all connections + stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns))) + for c := range conns { + c.Close() + } + + // don't block the kill. + close(killDone) + + // continue the loop and we wait for all the ConnState updates and will + // return from this goroutine when we're all done. otherwise we'll try to + // send those ConnState updates on closed channels. + + } + } +} + +func (s *server) serve() { + stats.BumpSum(s.stats, "serve", 1) + if s.certFile == "" && s.keyFile == "" { + s.serveErr <- s.server.Serve(s.listener) + } else { + s.serveErr <- s.server.ServeTLS(s.listener, s.certFile, s.keyFile) + } + close(s.serveDone) + close(s.serveErr) +} + +func (s *server) Wait() error { + if err := <-s.serveErr; !isUseOfClosedError(err) { + return err + } + return nil +} + +func (s *server) Stop() error { + s.stopOnce.Do(func() { + defer stats.BumpTime(s.stats, "stop.time").End() + stats.BumpSum(s.stats, "stop", 1) + + // first disable keep-alive for new connections + s.server.SetKeepAlivesEnabled(false) + + // then close the listener so new connections can't connect come thru + closeErr := s.listener.Close() + <-s.serveDone + + // then trigger the background goroutine to stop and wait for it + stopDone := make(chan struct{}) + s.stop <- stopDone + + // wait for stop + select { + case <-stopDone: + case <-s.clock.After(s.stopTimeout): + defer stats.BumpTime(s.stats, "kill.time").End() + stats.BumpSum(s.stats, "kill", 1) + + // stop timed out, wait for kill + killDone := make(chan struct{}) + s.kill <- killDone + select { + case <-killDone: + case <-s.clock.After(s.killTimeout): + // kill timed out, give up + stats.BumpSum(s.stats, "kill.timeout", 1) + } + } + + if closeErr != nil && !isUseOfClosedError(closeErr) { + stats.BumpSum(s.stats, "listener.close.error", 1) + s.stopErr = closeErr + } + }) + return s.stopErr +} + +func isUseOfClosedError(err error) bool { + if err == nil { + return false + } + if opErr, ok := err.(*net.OpError); ok { + err = opErr.Err + } + return err.Error() == "use of closed network connection" +} + +// ListenAndServe is a convenience function to serve and wait for a SIGTERM +// or SIGINT before shutting down. +func ListenAndServe(s *http.Server, hd *HTTP) error { + if hd == nil { + hd = &HTTP{} + } + hs, err := hd.ListenAndServe(s) + if err != nil { + return err + } + + waiterr := make(chan error, 1) + go func() { + defer close(waiterr) + waiterr <- hs.Wait() + }() + + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + + select { + case err := <-waiterr: + if err != nil { + return err + } + case <-signals: + signal.Stop(signals) + if err := hs.Stop(); err != nil { + return err + } + if err := <-waiterr; err != nil { + return err + } + } + return nil +} diff --git a/weed/util/net_timeout.go b/weed/util/net_timeout.go index b8068e67f..8acd50d42 100644 --- a/weed/util/net_timeout.go +++ b/weed/util/net_timeout.go @@ -66,11 +66,8 @@ func (c *Conn) Write(b []byte) (count int, e error) { } func (c *Conn) Close() error { - err := c.Conn.Close() - if err == nil { - stats.ConnectionClose() - } - return err + stats.ConnectionClose() + return c.Conn.Close() } func NewListener(addr string, timeout time.Duration) (net.Listener, error) { diff --git a/weed/util/queue.go b/weed/util/queue.go new file mode 100644 index 000000000..1e6211e0d --- /dev/null +++ b/weed/util/queue.go @@ -0,0 +1,61 @@ +package util + +import "sync" + +type node struct { + data interface{} + next *node +} + +type Queue struct { + head *node + tail *node + count int + sync.RWMutex +} + +func NewQueue() *Queue { + q := &Queue{} + return q +} + +func (q *Queue) Len() int { + q.RLock() + defer q.RUnlock() + return q.count +} + +func (q *Queue) Enqueue(item interface{}) { + q.Lock() + defer q.Unlock() + + n := &node{data: item} + + if q.tail == nil { + q.tail = n + q.head = n + } else { + q.tail.next = n + q.tail = n + } + q.count++ +} + +func (q *Queue) Dequeue() interface{} { + q.Lock() + defer q.Unlock() + + if q.head == nil { + return nil + } + + n := q.head + q.head = n.next + + if q.head == nil { + q.tail = nil + } + q.count-- + + return n.data +} diff --git a/weed/util/queue_unbounded.go b/weed/util/queue_unbounded.go new file mode 100644 index 000000000..664cd965e --- /dev/null +++ b/weed/util/queue_unbounded.go @@ -0,0 +1,45 @@ +package util + +import "sync" + +type UnboundedQueue struct { + outbound []string + outboundLock sync.RWMutex + inbound []string + inboundLock sync.RWMutex +} + +func NewUnboundedQueue() *UnboundedQueue { + q := &UnboundedQueue{} + return q +} + +func (q *UnboundedQueue) EnQueue(items ...string) { + q.inboundLock.Lock() + defer q.inboundLock.Unlock() + + q.outbound = append(q.outbound, items...) + +} + +func (q *UnboundedQueue) Consume(fn func([]string)) { + q.outboundLock.Lock() + defer q.outboundLock.Unlock() + + if len(q.outbound) == 0 { + q.inboundLock.Lock() + inbountLen := len(q.inbound) + if inbountLen > 0 { + t := q.outbound + q.outbound = q.inbound + q.inbound = t + } + q.inboundLock.Unlock() + } + + if len(q.outbound) > 0 { + fn(q.outbound) + q.outbound = q.outbound[:0] + } + +} diff --git a/weed/util/queue_unbounded_test.go b/weed/util/queue_unbounded_test.go new file mode 100644 index 000000000..2d02032cb --- /dev/null +++ b/weed/util/queue_unbounded_test.go @@ -0,0 +1,25 @@ +package util + +import "testing" + +func TestEnqueueAndConsume(t *testing.T) { + + q := NewUnboundedQueue() + + q.EnQueue("1", "2", "3") + + f := func(items []string) { + for _, t := range items { + println(t) + } + println("-----------------------") + } + q.Consume(f) + + q.Consume(f) + + q.EnQueue("4", "5") + q.EnQueue("6", "7") + q.Consume(f) + +} |
