383 lines
8.6 KiB
Go
383 lines
8.6 KiB
Go
package coresvc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/gob"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
"wpw-common/internal"
|
|
"wpw-common/internal/http"
|
|
|
|
"wpw-common/internal/vdext"
|
|
"wpw-common/pkg/gen/biz/core"
|
|
"wpw-common/pkg/gen/biz/core/exceptions"
|
|
"wpw-common/pkg/gen/biz/core/structs"
|
|
|
|
"github.com/go-playground/validator/v10"
|
|
"github.com/gofiber/fiber/v2/middleware/session"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// scopeFunc is the shape of a gorm scope, as accepted by (*gorm.DB).Scopes.
// It is a type alias (not a defined type), so any plain
// func(*gorm.DB) *gorm.DB value is interchangeable with it.
type scopeFunc = func(*gorm.DB) *gorm.DB
|
|
|
|
type CoreService struct {
|
|
vd *validator.Validate
|
|
logger *zap.SugaredLogger
|
|
db *gorm.DB
|
|
|
|
secret []byte
|
|
authExpireDur time.Duration
|
|
}
|
|
|
|
func init() {
|
|
gob.Register(&structs.User{})
|
|
}
|
|
|
|
// Compile-time proof that *CoreService implements the generated
// core.CoreService interface; the build breaks if a method is missing.
var (
	_ core.CoreService = (*CoreService)(nil)
)
|
|
|
|
type CoreV1ServiceParams struct {
|
|
Logger *zap.Logger
|
|
DB *gorm.DB
|
|
Secret []byte
|
|
}
|
|
|
|
func NewCoreServiceFromAppContext(appCtx *internal.AppContext) *CoreService {
|
|
var (
|
|
argName string
|
|
db, secret any
|
|
|
|
exist bool
|
|
)
|
|
// db
|
|
db, exist = appCtx.Data["db"]
|
|
if !exist {
|
|
argName = "db"
|
|
goto invalid_arguments
|
|
}
|
|
// secret
|
|
secret, exist = appCtx.Data["secret"]
|
|
if !exist {
|
|
argName = "secret"
|
|
goto invalid_arguments
|
|
}
|
|
return NewCoreService(CoreV1ServiceParams{
|
|
Logger: appCtx.Logger,
|
|
DB: db.(*gorm.DB),
|
|
Secret: []byte(secret.(string)),
|
|
})
|
|
|
|
invalid_arguments:
|
|
appCtx.Logger.Sugar().Fatal("core v1 service requires `%s` on application context", argName)
|
|
return nil
|
|
}
|
|
|
|
func NewCoreService(params CoreV1ServiceParams) *CoreService {
|
|
c := &CoreService{
|
|
vd: vdext.New("core_v1_dto"),
|
|
logger: params.Logger.Sugar().Named("core-v1"),
|
|
db: params.DB,
|
|
secret: params.Secret,
|
|
authExpireDur: time.Hour * 5,
|
|
}
|
|
|
|
token := jwt.New(jwt.SigningMethodEdDSA)
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
_ = claims
|
|
|
|
c.migrateSchemeAndData()
|
|
return c
|
|
}
|
|
|
|
func (c *CoreService) getEd25519KeyPair() (pri ed25519.PrivateKey, pub ed25519.PublicKey, err error) {
|
|
pub, pri, err = ed25519.GenerateKey(rand.Reader)
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) generateJwtTokenForUser(user *structs.User, jti string) (tokenStr string, err error) {
|
|
// user.PrivateKey
|
|
token := jwt.New(jwt.SigningMethodEdDSA)
|
|
claims := token.Claims.(jwt.MapClaims)
|
|
claims["exp"] = time.Now().Add(c.authExpireDur)
|
|
claims["iss"] = "core-v1"
|
|
claims["uid"] = user.ID
|
|
claims["jti"] = jti
|
|
claims["username"] = user.Username
|
|
var roleIds []int64
|
|
for _, role := range user.Roles {
|
|
roleIds = append(roleIds, role.ID)
|
|
}
|
|
claims["roles"] = roleIds
|
|
|
|
tokenStr, err = token.SignedString(ed25519.PrivateKey(user.PrivateKey))
|
|
return
|
|
|
|
}
|
|
|
|
func (c *CoreService) stripConfidentialUserData(u *structs.User) {
|
|
u.Password = ""
|
|
u.PrivateKey = nil
|
|
u.PublicKey = nil
|
|
}
|
|
|
|
func (c *CoreService) getSession(ctx context.Context) (sess *session.Session, err error) {
|
|
fctx := http.GetFiberFromContext(ctx)
|
|
sess, err = http.Session.Get(fctx)
|
|
if err != nil {
|
|
c.logger.Error("get session failed: ", err)
|
|
exc := exceptions.NewCoreServicesException()
|
|
exc.Code = 500
|
|
exc.Message = "unable to get session"
|
|
exc.Parameters = map[string]string{
|
|
"inner": err.Error(),
|
|
}
|
|
err = exc
|
|
return
|
|
}
|
|
if authToken := fctx.Request().Header.Peek("X-Hub-Auth-Token"); authToken != nil {
|
|
// tbd
|
|
}
|
|
return
|
|
}
|
|
|
|
// ErrBadRequest signals invalid client input; wrapServiceError maps it to code 400.
var ErrBadRequest = errors.New("bad request")

// ErrForbidden signals an authenticated caller lacking permission; wrapServiceError maps it to code 403.
var ErrForbidden = errors.New("forbidden")

// ErrNotAuthenticated signals a missing or invalid session; wrapServiceError maps it to code 401.
var ErrNotAuthenticated = errors.New("not authenticated")

// ErrInvalidState signals an unexpected server-side state; wrapServiceError maps it to code 500 with a redirect to "/".
var ErrInvalidState = errors.New("invalid state")
|
|
|
|
func (c *CoreService) checkSession(sess *session.Session) (user *structs.User, err error) {
|
|
userVal := sess.Get("user")
|
|
if userVal == nil {
|
|
err = ErrNotAuthenticated
|
|
return
|
|
}
|
|
var ok bool
|
|
user, ok = userVal.(*structs.User)
|
|
if !ok {
|
|
err = ErrNotAuthenticated
|
|
return
|
|
}
|
|
tx := c.db.Where("id = ?", user.ID).First(user)
|
|
if err = tx.Error; err != nil {
|
|
// must check if the user is still valid / no db error
|
|
// or we just revoke the session.
|
|
if err = sess.Destroy(); err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// TODO: also use the JWT ?
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) getAndCheckSession(ctx context.Context) (user *structs.User, err error) {
|
|
var sess *session.Session
|
|
if sess, err = c.getSession(ctx); err != nil {
|
|
return
|
|
}
|
|
if user, err = c.checkSession(sess); err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) getAndCheckSessionAdmin(ctx context.Context) (user *structs.User, err error) {
|
|
if user, err = c.getAndCheckSession(ctx); err != nil {
|
|
return
|
|
}
|
|
// user with type system
|
|
if !c.isSystemUser(user) {
|
|
err = ErrForbidden
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) isSystemUser(user *structs.User) bool {
|
|
for _, role := range user.Roles {
|
|
if role.RoleType == structs.RoleType_SYSTEM {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *CoreService) alert2Str(alert *structs.AlertInfo) (s string) {
|
|
var b []byte
|
|
var err error
|
|
if b, err = json.Marshal(alert); err != nil {
|
|
s = "{}"
|
|
return
|
|
}
|
|
s = string(b)
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) wrapServiceError(err error, customMessage ...string) error {
|
|
if errVal, ok := err.(*exceptions.CoreServicesException); ok {
|
|
return errVal
|
|
}
|
|
exc := exceptions.NewCoreServicesException()
|
|
exc.Message = err.Error()
|
|
switch err {
|
|
case gorm.ErrRecordNotFound:
|
|
exc.Code = 404
|
|
alert := &structs.AlertInfo{
|
|
Title: "Resource",
|
|
Description: "Resource not found",
|
|
}
|
|
if len(customMessage) >= 2 {
|
|
alert.Title = customMessage[0]
|
|
alert.Description = customMessage[1]
|
|
} else if len(customMessage) >= 1 {
|
|
alert.Description = customMessage[0]
|
|
}
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(alert),
|
|
}
|
|
case ErrForbidden:
|
|
exc.Code = 403
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Access Error",
|
|
Description: err.Error(),
|
|
}),
|
|
}
|
|
case ErrBadRequest:
|
|
exc.Code = 400
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Input Error",
|
|
Description: err.Error(),
|
|
}),
|
|
}
|
|
case ErrNotAuthenticated:
|
|
exc.Code = 401
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Auth Error",
|
|
Description: err.Error(),
|
|
}),
|
|
}
|
|
case ErrInvalidState:
|
|
exc.Code = 500
|
|
exc.Parameters = map[string]string{
|
|
"redirect": "/",
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Server Error",
|
|
Description: err.Error(),
|
|
}),
|
|
}
|
|
default:
|
|
exc.Code = 500
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Server Error",
|
|
Description: err.Error(),
|
|
}),
|
|
}
|
|
}
|
|
return exc
|
|
}
|
|
|
|
func (c *CoreService) checkRequest(ctx context.Context, v any) (err error) {
|
|
err = c.vd.StructCtx(ctx, v)
|
|
if err != nil {
|
|
exc := exceptions.NewCoreServicesException()
|
|
exc.Code = 400
|
|
exc.Message = err.Error()
|
|
exc.Parameters = map[string]string{
|
|
"alert": c.alert2Str(&structs.AlertInfo{
|
|
Title: "Validation Error",
|
|
Description: c.err2str(err),
|
|
}),
|
|
}
|
|
err = exc
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (c *CoreService) passwordHasher(plain string) string {
|
|
return passwordHasher(plain, c.secret)
|
|
}
|
|
|
|
func (c *CoreService) err2str(err error) (s string) {
|
|
var ve validator.ValidationErrors
|
|
if errors.As(err, &ve) {
|
|
var sb strings.Builder
|
|
for i, f := range ve {
|
|
sb.WriteString(f.Field())
|
|
sb.Write([]byte(` `))
|
|
sb.WriteString(f.Tag())
|
|
if i+1 < len(ve) {
|
|
sb.Write([]byte(`, `))
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
return err.Error()
|
|
}
|
|
|
|
// pagination2Offset correct pagination and return the offset
|
|
func (c *CoreService) pagination2Offset(pagination *structs.Pagination) (offset int, limit int) {
|
|
offset, limit = 0, 25
|
|
if pagination == nil {
|
|
return
|
|
}
|
|
if pagination.Page <= 0 {
|
|
pagination.Page = 1
|
|
}
|
|
switch {
|
|
case pagination.RowsPerPage > 100:
|
|
pagination.RowsPerPage = 100
|
|
case pagination.RowsPerPage <= 0:
|
|
pagination.RowsPerPage = int64(limit)
|
|
}
|
|
limit = int(pagination.RowsPerPage)
|
|
offset = int(pagination.Page-1) * limit
|
|
return
|
|
}
|
|
|
|
// ScPaginate gorm scope function
|
|
func (c *CoreService) ScPaginate(pagination *structs.Pagination) scopeFunc {
|
|
return func(db *gorm.DB) *gorm.DB {
|
|
offset, limit := c.pagination2Offset(pagination)
|
|
// use id > ? next time.
|
|
return db.Limit(limit).Offset(offset)
|
|
}
|
|
}
|
|
|
|
// ScSearchTerm gorm scope function
|
|
func (c *CoreService) ScSearchTerm(searchTerm string, fields ...string) scopeFunc {
|
|
// bloat FTS-like search :(
|
|
var sb strings.Builder
|
|
var values []any
|
|
|
|
if len(fields) <= 0 {
|
|
goto ret
|
|
}
|
|
for i, field := range fields {
|
|
fmt.Fprintf(&sb, "%s LIKE ?", field)
|
|
if i < len(fields)-1 {
|
|
sb.WriteString(" OR ")
|
|
}
|
|
values = append(values, fmt.Sprintf("%%%s%%", searchTerm))
|
|
}
|
|
|
|
ret:
|
|
return func(db *gorm.DB) *gorm.DB {
|
|
return db.Where(sb.String(), values...)
|
|
}
|
|
}
|