aboutsummaryrefslogtreecommitdiff
path: root/weed/security
diff options
context:
space:
mode:
Diffstat (limited to 'weed/security')
-rw-r--r--weed/security/guard.go71
-rw-r--r--weed/security/jwt.go23
2 files changed, 91 insertions, 3 deletions
diff --git a/weed/security/guard.go b/weed/security/guard.go
index 17fe2ea9e..eb120529e 100644
--- a/weed/security/guard.go
+++ b/weed/security/guard.go
@@ -1,12 +1,15 @@
package security
import (
+ "bytes"
"errors"
"fmt"
"net"
"net/http"
"strings"
+ "github.com/valyala/fasthttp"
+
"github.com/chrislusf/seaweedfs/weed/glog"
)
@@ -62,7 +65,7 @@ func NewGuard(whiteList []string, signingKey string, expiresAfterSec int, readSi
return g
}
-func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
+func (g *Guard) OldWhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
if !g.isWriteActive {
//if no security needed, just skip all checking
return f
@@ -76,6 +79,20 @@ func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w
}
}
+func (g *Guard) WhiteList(f func(ctx *fasthttp.RequestCtx)) func(ctx *fasthttp.RequestCtx) {
+ if !g.isWriteActive {
+ //if no security needed, just skip all checking
+ return f
+ }
+ return func(ctx *fasthttp.RequestCtx) {
+ if err := g.fastCheckWhiteList(ctx); err != nil {
+ ctx.SetStatusCode(http.StatusUnauthorized)
+ return
+ }
+ f(ctx)
+ }
+}
+
func GetActualRemoteHost(r *http.Request) (host string, err error) {
host = r.Header.Get("HTTP_X_FORWARDED_FOR")
if host == "" {
@@ -90,6 +107,22 @@ func GetActualRemoteHost(r *http.Request) (host string, err error) {
return
}
+func FastGetActualRemoteHost(ctx *fasthttp.RequestCtx) (theHost string, err error) {
+ host := ctx.Request.Header.Peek("HTTP_X_FORWARDED_FOR")
+ if host == nil {
+ host = ctx.Request.Header.Peek("X-FORWARDED-FOR")
+ }
+ commaIndex := bytes.IndexByte(host, ',')
+ if commaIndex >= 0 {
+ host = host[0:commaIndex]
+ }
+ if host == nil {
+ shost, _, serr := net.SplitHostPort(ctx.RemoteAddr().String())
+ return shost, serr
+ }
+ return string(host), nil
+}
+
func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error {
if len(g.whiteList) == 0 {
return nil
@@ -125,3 +158,39 @@ func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error {
glog.V(0).Infof("Not in whitelist: %s", r.RemoteAddr)
return fmt.Errorf("Not in whitelis: %s", r.RemoteAddr)
}
+
+func (g *Guard) fastCheckWhiteList(ctx *fasthttp.RequestCtx) error {
+ if len(g.whiteList) == 0 {
+ return nil
+ }
+
+ host, err := FastGetActualRemoteHost(ctx)
+ if err == nil {
+ for _, ip := range g.whiteList {
+
+ // If the whitelist entry contains a "/" it
+ // is a CIDR range, and we should check the
+ // remote host is within it
+ if strings.Contains(ip, "/") {
+ _, cidrnet, err := net.ParseCIDR(ip)
+ if err != nil {
+ panic(err)
+ }
+ remote := net.ParseIP(host)
+ if cidrnet.Contains(remote) {
+ return nil
+ }
+ }
+
+ //
+ // Otherwise we're looking for a literal match.
+ //
+ if ip == host {
+ return nil
+ }
+ }
+ }
+
+ glog.V(0).Infof("Not in whitelist: %s", ctx.RemoteAddr())
+ return fmt.Errorf("Not in whitelis: %s", ctx.RemoteAddr())
+}
diff --git a/weed/security/jwt.go b/weed/security/jwt.go
index 0bd7fa974..c6da5c7aa 100644
--- a/weed/security/jwt.go
+++ b/weed/security/jwt.go
@@ -1,13 +1,16 @@
package security
import (
+ "bytes"
"fmt"
"net/http"
"strings"
"time"
+ "github.com/dgrijalva/jwt-go"
+ "github.com/valyala/fasthttp"
+
"github.com/chrislusf/seaweedfs/weed/glog"
- jwt "github.com/dgrijalva/jwt-go"
)
type EncodedJwt string
@@ -39,7 +42,7 @@ func GenJwt(signingKey SigningKey, expiresAfterSec int, fileId string) EncodedJw
return EncodedJwt(encoded)
}
-func GetJwt(r *http.Request) EncodedJwt {
+func OldGetJwt(r *http.Request) EncodedJwt {
// Get token from query params
tokenStr := r.URL.Query().Get("jwt")
@@ -55,6 +58,22 @@ func GetJwt(r *http.Request) EncodedJwt {
return EncodedJwt(tokenStr)
}
+func GetJwt(ctx *fasthttp.RequestCtx) EncodedJwt {
+
+ // Get token from query params
+ tokenStr := ctx.FormValue("jwt")
+
+ // Get token from authorization header
+ if tokenStr == nil {
+ bearer := ctx.Request.Header.Peek("Authorization")
+ if len(bearer) > 7 && string(bytes.ToUpper(bearer[0:6])) == "BEARER" {
+ tokenStr = bearer[7:]
+ }
+ }
+
+ return EncodedJwt(tokenStr)
+}
+
func DecodeJwt(signingKey SigningKey, tokenString EncodedJwt) (token *jwt.Token, err error) {
// check exp, nbf
return jwt.ParseWithClaims(string(tokenString), &SeaweedFileIdClaims{}, func(token *jwt.Token) (interface{}, error) {