aboutsummaryrefslogtreecommitdiff
path: root/weed/security
diff options
context:
space:
mode:
Diffstat (limited to 'weed/security')
-rw-r--r--weed/security/guard.go56
-rw-r--r--weed/security/jwt.go55
-rw-r--r--weed/security/tls.go66
3 files changed, 95 insertions, 82 deletions
diff --git a/weed/security/guard.go b/weed/security/guard.go
index fd9c8e0b3..84a415253 100644
--- a/weed/security/guard.go
+++ b/weed/security/guard.go
@@ -41,21 +41,21 @@ https://github.com/pkieltyka/jwtauth/blob/master/jwtauth.go
*/
type Guard struct {
- whiteList []string
- SecretKey Secret
+ whiteList []string
+ SigningKey SigningKey
isActive bool
}
-func NewGuard(whiteList []string, secretKey string) *Guard {
- g := &Guard{whiteList: whiteList, SecretKey: Secret(secretKey)}
- g.isActive = len(g.whiteList) != 0 || len(g.SecretKey) != 0
+func NewGuard(whiteList []string, signingKey string) *Guard {
+ g := &Guard{whiteList: whiteList, SigningKey: SigningKey(signingKey)}
+ g.isActive = len(g.whiteList) != 0 || len(g.SigningKey) != 0
return g
}
func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
if !g.isActive {
- //if no security needed, just skip all checkings
+ //if no security needed, just skip all checking
return f
}
return func(w http.ResponseWriter, r *http.Request) {
@@ -67,20 +67,6 @@ func (g *Guard) WhiteList(f func(w http.ResponseWriter, r *http.Request)) func(w
}
}
-func (g *Guard) Secure(f func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
- if !g.isActive {
- //if no security needed, just skip all checkings
- return f
- }
- return func(w http.ResponseWriter, r *http.Request) {
- if err := g.checkJwt(w, r); err != nil {
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
- f(w, r)
- }
-}
-
func GetActualRemoteHost(r *http.Request) (host string, err error) {
host = r.Header.Get("HTTP_X_FORWARDED_FOR")
if host == "" {
@@ -130,33 +116,3 @@ func (g *Guard) checkWhiteList(w http.ResponseWriter, r *http.Request) error {
glog.V(0).Infof("Not in whitelist: %s", r.RemoteAddr)
return fmt.Errorf("Not in whitelis: %s", r.RemoteAddr)
}
-
-func (g *Guard) checkJwt(w http.ResponseWriter, r *http.Request) error {
- if g.checkWhiteList(w, r) == nil {
- return nil
- }
-
- if len(g.SecretKey) == 0 {
- return nil
- }
-
- tokenStr := GetJwt(r)
-
- if tokenStr == "" {
- return ErrUnauthorized
- }
-
- // Verify the token
- token, err := DecodeJwt(g.SecretKey, tokenStr)
- if err != nil {
- glog.V(1).Infof("Token verification error from %s: %v", r.RemoteAddr, err)
- return ErrUnauthorized
- }
- if !token.Valid {
- glog.V(1).Infof("Token invliad from %s: %v", r.RemoteAddr, tokenStr)
- return ErrUnauthorized
- }
-
- glog.V(1).Infof("No permission from %s", r.RemoteAddr)
- return fmt.Errorf("No write permission from %s", r.RemoteAddr)
-}
diff --git a/weed/security/jwt.go b/weed/security/jwt.go
index 46b7efaaf..45a77f093 100644
--- a/weed/security/jwt.go
+++ b/weed/security/jwt.go
@@ -1,9 +1,9 @@
package security
import (
+ "fmt"
"net/http"
"strings"
-
"time"
"github.com/chrislusf/seaweedfs/weed/glog"
@@ -11,21 +11,28 @@ import (
)
type EncodedJwt string
-type Secret string
+type SigningKey []byte
+
+type SeaweedFileIdClaims struct {
+ Fid string `json:"fid"`
+ jwt.StandardClaims
+}
-func GenJwt(secret Secret, fileId string) EncodedJwt {
- if secret == "" {
+func GenJwt(signingKey SigningKey, fileId string) EncodedJwt {
+ if len(signingKey) == 0 {
return ""
}
- t := jwt.New(jwt.GetSigningMethod("HS256"))
- t.Claims = &jwt.StandardClaims{
- ExpiresAt: time.Now().Add(time.Second * 10).Unix(),
- Subject: fileId,
+ claims := SeaweedFileIdClaims{
+ fileId,
+ jwt.StandardClaims{
+ ExpiresAt: time.Now().Add(time.Second * 10).Unix(),
+ },
}
- encoded, e := t.SignedString(secret)
+ t := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ encoded, e := t.SignedString([]byte(signingKey))
if e != nil {
- glog.V(0).Infof("Failed to sign claims: %v", t.Claims)
+ glog.V(0).Infof("Failed to sign claims %+v: %v", t.Claims, e)
return ""
}
return EncodedJwt(encoded)
@@ -44,31 +51,15 @@ func GetJwt(r *http.Request) EncodedJwt {
}
}
- // Get token from cookie
- if tokenStr == "" {
- cookie, err := r.Cookie("jwt")
- if err == nil {
- tokenStr = cookie.Value
- }
- }
-
return EncodedJwt(tokenStr)
}
-func EncodeJwt(secret Secret, claims *jwt.StandardClaims) (EncodedJwt, error) {
- if secret == "" {
- return "", nil
- }
-
- t := jwt.New(jwt.GetSigningMethod("HS256"))
- t.Claims = claims
- encoded, e := t.SignedString(secret)
- return EncodedJwt(encoded), e
-}
-
-func DecodeJwt(secret Secret, tokenString EncodedJwt) (token *jwt.Token, err error) {
+func DecodeJwt(signingKey SigningKey, tokenString EncodedJwt) (token *jwt.Token, err error) {
// check exp, nbf
- return jwt.Parse(string(tokenString), func(token *jwt.Token) (interface{}, error) {
- return secret, nil
+ return jwt.ParseWithClaims(string(tokenString), &SeaweedFileIdClaims{}, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ return nil, fmt.Errorf("unknown token method")
+ }
+ return []byte(signingKey), nil
})
}
diff --git a/weed/security/tls.go b/weed/security/tls.go
new file mode 100644
index 000000000..e81ba4831
--- /dev/null
+++ b/weed/security/tls.go
@@ -0,0 +1,66 @@
+package security
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "github.com/spf13/viper"
+ "io/ioutil"
+
+ "github.com/chrislusf/seaweedfs/weed/glog"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+)
+
+func LoadServerTLS(config *viper.Viper, component string) grpc.ServerOption {
+ if config == nil {
+ return nil
+ }
+
+ // load cert/key, ca cert
+ cert, err := tls.LoadX509KeyPair(config.GetString(component+".cert"), config.GetString(component+".key"))
+ if err != nil {
+ glog.Errorf("load cert/key error: %v", err)
+ return nil
+ }
+ caCert, err := ioutil.ReadFile(config.GetString("ca"))
+ if err != nil {
+ glog.Errorf("read ca cert file error: %v", err)
+ return nil
+ }
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
+ ta := credentials.NewTLS(&tls.Config{
+ Certificates: []tls.Certificate{cert},
+ ClientCAs: caCertPool,
+ ClientAuth: tls.RequireAndVerifyClientCert,
+ })
+
+ return grpc.Creds(ta)
+}
+
+func LoadClientTLS(config *viper.Viper, component string) grpc.DialOption {
+ if config == nil {
+ return grpc.WithInsecure()
+ }
+
+ // load cert/key, cacert
+ cert, err := tls.LoadX509KeyPair(config.GetString(component+".cert"), config.GetString(component+".key"))
+ if err != nil {
+ glog.Errorf("load cert/key error: %v", err)
+ return grpc.WithInsecure()
+ }
+ caCert, err := ioutil.ReadFile(config.GetString("ca"))
+ if err != nil {
+ glog.Errorf("read ca cert file error: %v", err)
+ return grpc.WithInsecure()
+ }
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
+
+ ta := credentials.NewTLS(&tls.Config{
+ Certificates: []tls.Certificate{cert},
+ RootCAs: caCertPool,
+ InsecureSkipVerify: true,
+ })
+ return grpc.WithTransportCredentials(ta)
+}