You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
99 lines
2.4 KiB
99 lines
2.4 KiB
2 months ago
|
package auth
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||
|
"github.com/go-kratos/kratos/v2/transport"
|
||
|
"github.com/golang-jwt/jwt/v4"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// currentUserCtxKey is the private type of the context key for the
// authenticated user. A dedicated named type (instead of the built-in
// struct{} type) guarantees the key cannot collide with context keys set by
// any other package, because context lookups compare both type and value.
type currentUserCtxKey struct{}

// currentUserKey is the context key under which WithContext stores, and
// FromContext reads, the *CurrentUser.
var currentUserKey = currentUserCtxKey{}
|
||
|
|
||
|
// CurrentUser describes the authenticated caller of a request.
type CurrentUser struct {
	// UserID is taken from the "userid" claim of the request's JWT.
	UserID uint
}
|
||
|
|
||
|
// GenerateToken
|
||
|
//
|
||
|
// @Description:
|
||
|
// @param secret
|
||
|
// @param userid
|
||
|
// @return string
|
||
|
func GenerateToken(secret string, userid uint) string {
|
||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||
|
"userid": userid,
|
||
|
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
|
||
|
})
|
||
|
|
||
|
// Sign and get the complete encoded token as a string using the secret
|
||
|
tokenString, err := token.SignedString([]byte(secret))
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return tokenString
|
||
|
}
|
||
|
|
||
|
// JWTAuth
|
||
|
//
|
||
|
// @Description:
|
||
|
// @param secret
|
||
|
// @return middleware.Middleware
|
||
|
func JWTAuth(secret string) middleware.Middleware {
|
||
|
return func(handler middleware.Handler) middleware.Handler {
|
||
|
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||
|
tokenString := tr.RequestHeader().Get("Authorization")
|
||
|
auths := strings.SplitN(tokenString, " ", 2)
|
||
|
if len(auths) != 2 || !strings.EqualFold(auths[0], "Token") {
|
||
|
return nil, errors.New("jwt token missing")
|
||
|
}
|
||
|
|
||
|
token, err := jwt.Parse(auths[1], func(token *jwt.Token) (interface{}, error) {
|
||
|
// Don't forget to validate the alg is what you expect:
|
||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||
|
}
|
||
|
return []byte(secret), nil
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||
|
// put CurrentUser into ctx
|
||
|
if u, ok := claims["userid"]; ok {
|
||
|
ctx = WithContext(ctx, &CurrentUser{UserID: uint(u.(float64))})
|
||
|
}
|
||
|
} else {
|
||
|
return nil, errors.New("Token Invalid")
|
||
|
}
|
||
|
}
|
||
|
return handler(ctx, req)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// FromContext
|
||
|
//
|
||
|
// @Description:
|
||
|
// @param ctx
|
||
|
// @return *CurrentUser
|
||
|
func FromContext(ctx context.Context) *CurrentUser {
|
||
|
return ctx.Value(currentUserKey).(*CurrentUser)
|
||
|
}
|
||
|
|
||
|
// WithContext
|
||
|
//
|
||
|
// @Description:
|
||
|
// @param ctx
|
||
|
// @param user
|
||
|
// @return context.Context
|
||
|
func WithContext(ctx context.Context, user *CurrentUser) context.Context {
|
||
|
return context.WithValue(ctx, currentUserKey, user)
|
||
|
}
|