You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

98 lines
2.4 KiB

package auth
import (
"context"
"errors"
"fmt"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/golang-jwt/jwt/v4"
"strings"
"time"
)
// currentUserCtxKey is the package-private type of the context key that holds
// the authenticated user. Using a distinct named type means no other package
// can construct an equal key; a bare struct{}{} key would compare equal to any
// other package's struct{}{} key and could be overwritten or read by it.
type currentUserCtxKey struct{}

// currentUserKey is the context key under which WithContext stores, and
// FromContext reads, the *CurrentUser.
var currentUserKey = currentUserCtxKey{}

// CurrentUser identifies the authenticated caller of a request.
type CurrentUser struct {
	// UserID is the user identifier; JWTAuth fills it from the token's
	// "userid" claim.
	UserID uint
}
// GenerateToken returns a compact JWT, signed with HS256 using secret, that
// carries the given user ID.
//
// The token holds exactly two claims: "userid" (the supplied ID) and the
// not-before claim "nbf", pinned to 2015-10-10 12:00:00 UTC. No "exp" claim
// is set.
// NOTE(review): without "exp" the token carries no expiry of its own —
// confirm this is intended.
//
// It panics if signing returns an error.
//
// @param secret HMAC key used to sign the token
// @param userid user ID stored in the "userid" claim
// @return string the signed token
func GenerateToken(secret string, userid uint) string {
	notBefore := time.Date(2015, time.October, 10, 12, 0, 0, 0, time.UTC)
	claims := jwt.MapClaims{
		"nbf":    notBefore.Unix(),
		"userid": userid,
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		panic(err)
	}
	return signed
}
// JWTAuth returns a Kratos middleware that authenticates server requests with
// an HMAC-signed JWT.
//
// The token is read from the "Authorization" request header, which must have
// the form "Token <jwt>" (the scheme is compared case-insensitively). The JWT
// must be signed with an HMAC method using secret. When the token carries a
// "userid" claim, a *CurrentUser with that ID is stored in the context passed
// to the next handler; a "userid" that is not a non-negative number is
// rejected with an error rather than panicking.
//
// NOTE(review): when ctx carries no server transport the request is passed to
// the next handler without any authentication check — confirm this fail-open
// behaviour is intended.
//
// @param secret HMAC key the token signature is verified against
// @return middleware.Middleware
func JWTAuth(secret string) middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				return handler(ctx, req)
			}

			auths := strings.SplitN(tr.RequestHeader().Get("Authorization"), " ", 2)
			if len(auths) != 2 || !strings.EqualFold(auths[0], "Token") {
				return nil, errors.New("jwt token missing")
			}

			token, err := jwt.Parse(auths[1], func(t *jwt.Token) (interface{}, error) {
				// Only accept HMAC algorithms so the token cannot choose a
				// different verification scheme than the shared secret.
				if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"])
				}
				return []byte(secret), nil
			})
			if err != nil {
				return nil, err
			}

			claims, ok := token.Claims.(jwt.MapClaims)
			if !ok || !token.Valid {
				return nil, errors.New("Token Invalid")
			}

			if u, ok := claims["userid"]; ok {
				// MapClaims is filled by encoding/json, which decodes JSON numbers
				// into float64. Anything else (or a negative value, whose uint
				// conversion is implementation-defined) is rejected instead of
				// panicking on an unchecked type assertion.
				id, ok := u.(float64)
				if !ok || id < 0 {
					return nil, fmt.Errorf("jwt token has invalid userid claim: %v", u)
				}
				ctx = WithContext(ctx, &CurrentUser{UserID: uint(id)})
			}

			return handler(ctx, req)
		}
	}
}
// FromContext returns the *CurrentUser stored in ctx by WithContext.
//
// It returns nil when ctx carries no user. That can happen, for example, when
// JWTAuth accepted a token without a "userid" claim. Callers must therefore
// nil-check the result.
//
// @param ctx context that may carry the current user
// @return *CurrentUser the stored user, or nil if none is present
func FromContext(ctx context.Context) *CurrentUser {
	// The comma-ok form yields nil instead of panicking when the value is
	// absent or was stored with another type; the bool is intentionally
	// discarded because nil already signals "no user".
	user, _ := ctx.Value(currentUserKey).(*CurrentUser)
	return user
}
// WithContext derives a child of ctx that carries user under the package's
// current-user key, where FromContext looks it up.
//
// @param ctx parent context
// @param user user to attach to the derived context
// @return context.Context the derived context holding user
func WithContext(ctx context.Context, user *CurrentUser) context.Context {
	derived := context.WithValue(ctx, currentUserKey, user)
	return derived
}