diff options
| author | Chris Lu <chris.lu@gmail.com> | 2020-02-10 20:23:04 -0800 |
|---|---|---|
| committer | Chris Lu <chris.lu@gmail.com> | 2020-02-10 20:23:04 -0800 |
| commit | ac3fc922566c42fb2d1d5603e7b0c167b868fce7 (patch) | |
| tree | 3451368af70559711e435ba3dbb92c3752ff919c /weed/security/guard.go | |
| parent | 29945fad51320deb7c72f57d1c7a84bcc51429da (diff) | |
| download | seaweedfs-origin/fasthttp.tar.xz seaweedfs-origin/fasthttp.zip | |
partially doneorigin/fasthttp
Diffstat (limited to 'weed/security/guard.go')
| -rw-r--r-- | weed/security/guard.go | 71 |
1 files changed, 70 insertions, 1 deletions
diff --git a/weed/security/guard.go b/weed/security/guard.go index 17fe2ea9e..eb120529e 100644 --- a/weed/security/guard.go +++ b/weed/security/guard.go @@ -1,12 +1,15 @@ package security import ( + "bytes" "errors" "fmt" "net" "net/http" "strings" + "github.com/valyala/fasthttp" + "github.com/chrislusf/seaweedfs/weed/glog" ) @@ -62,7 +65,7 @@ func NewGuard(whiteList []string, signingKey string, expiresAfterSec int, readSi return g } -func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { +func (g *Guard) OldWhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { if !g.isWriteActive { //if no security needed, just skip all checking return f @@ -76,6 +79,20 @@ func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w } } +func (g *Guard) WhiteList(f func(ctx *fasthttp.RequestCtx)) func(ctx *fasthttp.RequestCtx) { + if !g.isWriteActive { + //if no security needed, just skip all checking + return f + } + return func(ctx *fasthttp.RequestCtx) { + if err := g.fastCheckWhiteList(ctx); err != nil { + ctx.SetStatusCode(http.StatusUnauthorized) + return + } + f(ctx) + } +} + func GetActualRemoteHost(r *http.Request) (host string, err error) { host = r.Header.Get("HTTP_X_FORWARDED_FOR") if host == "" { @@ -90,6 +107,22 @@ func GetActualRemoteHost(r *http.Request) (host string, err error) { return } +func FastGetActualRemoteHost(ctx *fasthttp.RequestCtx) (theHost string, err error) { + host := ctx.Request.Header.Peek("HTTP_X_FORWARDED_FOR") + if host == nil { + host = ctx.Request.Header.Peek("X-FORWARDED-FOR") + } + commaIndex := bytes.IndexByte(host, ',') + if commaIndex >= 0 { + host = host[0:commaIndex] + } + if host == nil { + shost, _, serr := net.SplitHostPort(ctx.RemoteAddr().String()) + return shost, serr + } + return string(host), nil +} + func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error { if len(g.whiteList) == 0 { return nil @@ -125,3 +158,39 @@ func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error { glog.V(0).Infof("Not in whitelist: %s", r.RemoteAddr) return fmt.Errorf("Not in whitelis: %s", r.RemoteAddr) } + +func (g *Guard) fastCheckWhiteList(ctx *fasthttp.RequestCtx) error { + if len(g.whiteList) == 0 { + return nil + } + + host, err := FastGetActualRemoteHost(ctx) + if err == nil { + for _, ip := range g.whiteList { + + // If the whitelist entry contains a "/" it + // is a CIDR range, and we should check the + // remote host is within it + if strings.Contains(ip, "/") { + _, cidrnet, err := net.ParseCIDR(ip) + if err != nil { + panic(err) + } + remote := net.ParseIP(host) + if cidrnet.Contains(remote) { + return nil + } + } + + // + // Otherwise we're looking for a literal match. + // + if ip == host { + return nil + } + } + } + + glog.V(0).Infof("Not in whitelist: %s", ctx.RemoteAddr()) + return fmt.Errorf("Not in whitelis: %s", ctx.RemoteAddr()) +} |
