diff options
Diffstat (limited to 'weed/security/guard.go')
| -rw-r--r-- | weed/security/guard.go | 69 |
1 files changed, 40 insertions, 29 deletions
diff --git a/weed/security/guard.go b/weed/security/guard.go index 14aacb83c..f92b10044 100644 --- a/weed/security/guard.go +++ b/weed/security/guard.go @@ -3,11 +3,10 @@ package security import ( "errors" "fmt" + "github.com/seaweedfs/seaweedfs/weed/glog" "net" "net/http" "strings" - - "github.com/seaweedfs/seaweedfs/weed/glog" ) var ( @@ -40,24 +39,25 @@ Referenced: https://github.com/pkieltyka/jwtauth/blob/master/jwtauth.go */ type Guard struct { - whiteList []string + whiteListIp map[string]struct{} + whiteListCIDR map[string]*net.IPNet SigningKey SigningKey ExpiresAfterSec int ReadSigningKey SigningKey ReadExpiresAfterSec int - isWriteActive bool + isWriteActive bool + isEmptyWhiteList bool } func NewGuard(whiteList []string, signingKey string, expiresAfterSec int, readSigningKey string, readExpiresAfterSec int) *Guard { g := &Guard{ - whiteList: whiteList, SigningKey: SigningKey(signingKey), ExpiresAfterSec: expiresAfterSec, ReadSigningKey: SigningKey(readSigningKey), ReadExpiresAfterSec: readExpiresAfterSec, } - g.isWriteActive = len(g.whiteList) != 0 || len(g.SigningKey) != 0 + g.UpdateWhiteList(whiteList) return g } @@ -90,37 +90,48 @@ func GetActualRemoteHost(r *http.Request) (host string, err error) { } func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error { - if len(g.whiteList) == 0 { + if g.isEmptyWhiteList { return nil } host, err := GetActualRemoteHost(r) - if err == nil { - for _, ip := range g.whiteList { - - // If the whitelist entry contains a "/" it - // is a CIDR range, and we should check the - // remote host is within it - if strings.Contains(ip, "/") { - _, cidrnet, err := net.ParseCIDR(ip) - if err != nil { - panic(err) - } - remote := net.ParseIP(host) - if cidrnet.Contains(remote) { - return nil - } - } + if err != nil { + return fmt.Errorf("get actual remote host %s in checkWhiteList failed: %v", r.RemoteAddr, err) + } - // - // Otherwise we're looking for a literal match. - // - if ip == host { - return nil - } + if _, ok := g.whiteListIp[host]; ok { + return nil + } + + for _, cidrnet := range g.whiteListCIDR { + // If the whitelist entry contains a "/" it + // is a CIDR range, and we should check the + remote := net.ParseIP(host) + if cidrnet.Contains(remote) { + return nil } } glog.V(0).Infof("Not in whitelist: %s", r.RemoteAddr) return fmt.Errorf("Not in whitelist: %s", r.RemoteAddr) } + +func (g *Guard) UpdateWhiteList(whiteList []string) { + whiteListIp := make(map[string]struct{}) + whiteListCIDR := make(map[string]*net.IPNet) + for _, ip := range whiteList { + if strings.Contains(ip, "/") { + _, cidrnet, err := net.ParseCIDR(ip) + if err != nil { + glog.Errorf("Parse CIDR %s in whitelist failed: %v", ip, err) + } + whiteListCIDR[ip] = cidrnet + } else { + whiteListIp[ip] = struct{}{} + } + } + g.isEmptyWhiteList = len(whiteListIp) == 0 && len(whiteListCIDR) == 0 + g.isWriteActive = !g.isEmptyWhiteList || len(g.SigningKey) != 0 + g.whiteListIp = whiteListIp + g.whiteListCIDR = whiteListCIDR +} |
