aboutsummaryrefslogtreecommitdiff
path: root/weed/security/guard.go
diff options
context:
space:
mode:
authorChris Lu <chris.lu@gmail.com>2020-02-10 20:23:04 -0800
committerChris Lu <chris.lu@gmail.com>2020-02-10 20:23:04 -0800
commitac3fc922566c42fb2d1d5603e7b0c167b868fce7 (patch)
tree3451368af70559711e435ba3dbb92c3752ff919c /weed/security/guard.go
parent29945fad51320deb7c72f57d1c7a84bcc51429da (diff)
downloadseaweedfs-origin/fasthttp.tar.xz
seaweedfs-origin/fasthttp.zip
partially doneorigin/fasthttp
Diffstat (limited to 'weed/security/guard.go')
-rw-r--r--weed/security/guard.go71
1 files changed, 70 insertions, 1 deletions
diff --git a/weed/security/guard.go b/weed/security/guard.go
index 17fe2ea9e..eb120529e 100644
--- a/weed/security/guard.go
+++ b/weed/security/guard.go
@@ -1,12 +1,15 @@
package security
import (
+ "bytes"
"errors"
"fmt"
"net"
"net/http"
"strings"
+ "github.com/valyala/fasthttp"
+
"github.com/chrislusf/seaweedfs/weed/glog"
)
@@ -62,7 +65,7 @@ func NewGuard(whiteList []string, signingKey string, expiresAfterSec int, readSi
return g
}
-func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
+func (g *Guard) OldWhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
if !g.isWriteActive {
//if no security needed, just skip all checking
return f
@@ -76,6 +79,20 @@ func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w
}
}
+func (g *Guard) WhiteList(f func(ctx *fasthttp.RequestCtx)) func(ctx *fasthttp.RequestCtx) {
+ if !g.isWriteActive {
+ //if no security needed, just skip all checking
+ return f
+ }
+ return func(ctx *fasthttp.RequestCtx) {
+ if err := g.fastCheckWhiteList(ctx); err != nil {
+ ctx.SetStatusCode(http.StatusUnauthorized)
+ return
+ }
+ f(ctx)
+ }
+}
+
func GetActualRemoteHost(r *http.Request) (host string, err error) {
host = r.Header.Get("HTTP_X_FORWARDED_FOR")
if host == "" {
@@ -90,6 +107,22 @@ func GetActualRemoteHost(r *http.Request) (host string, err error) {
return
}
+func FastGetActualRemoteHost(ctx *fasthttp.RequestCtx) (theHost string, err error) {
+ host := ctx.Request.Header.Peek("HTTP_X_FORWARDED_FOR")
+ if host == nil {
+ host = ctx.Request.Header.Peek("X-FORWARDED-FOR")
+ }
+ commaIndex := bytes.IndexByte(host, ',')
+ if commaIndex >= 0 {
+ host = host[0:commaIndex]
+ }
+ if host == nil {
+ shost, _, serr := net.SplitHostPort(ctx.RemoteAddr().String())
+ return shost, serr
+ }
+ return string(host), nil
+}
+
func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error {
if len(g.whiteList) == 0 {
return nil
@@ -125,3 +158,39 @@ func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error {
glog.V(0).Infof("Not in whitelist: %s", r.RemoteAddr)
return fmt.Errorf("Not in whitelis: %s", r.RemoteAddr)
}
+
+func (g *Guard) fastCheckWhiteList(ctx *fasthttp.RequestCtx) error {
+ if len(g.whiteList) == 0 {
+ return nil
+ }
+
+ host, err := FastGetActualRemoteHost(ctx)
+ if err == nil {
+ for _, ip := range g.whiteList {
+
+ // If the whitelist entry contains a "/" it
+ // is a CIDR range, and we should check the
+ // remote host is within it
+ if strings.Contains(ip, "/") {
+ _, cidrnet, err := net.ParseCIDR(ip)
+ if err != nil {
+ panic(err)
+ }
+ remote := net.ParseIP(host)
+ if cidrnet.Contains(remote) {
+ return nil
+ }
+ }
+
+ //
+ // Otherwise we're looking for a literal match.
+ //
+ if ip == host {
+ return nil
+ }
+ }
+ }
+
+ glog.V(0).Infof("Not in whitelist: %s", ctx.RemoteAddr())
+ return fmt.Errorf("Not in whitelis: %s", ctx.RemoteAddr())
+}