// wpw-final/internal/biz/core/core.go
// 2022-12-02 20:40:23 +07:00
//
// 383 lines
// 8.6 KiB
// Go

package coresvc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"wpw-common/internal"
"wpw-common/internal/http"
"wpw-common/internal/vdext"
"wpw-common/pkg/gen/biz/core"
"wpw-common/pkg/gen/biz/core/exceptions"
"wpw-common/pkg/gen/biz/core/structs"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2/middleware/session"
"github.com/golang-jwt/jwt/v4"
"go.uber.org/zap"
"gorm.io/gorm"
)
// scopeFunc is the gorm scope signature returned by ScPaginate and
// ScSearchTerm: it takes a query builder and returns the modified builder.
type scopeFunc = func(db *gorm.DB) *gorm.DB
// CoreService is the gorm-backed implementation of core.CoreService.
// Construct it with NewCoreService or NewCoreServiceFromAppContext so that all
// fields are initialized.
type CoreService struct {
// vd validates request DTOs (used by checkRequest).
vd *validator.Validate
// logger is the service's sugared logger, named "core-v1" by NewCoreService.
logger *zap.SugaredLogger
// db is the database handle used for all queries.
db *gorm.DB
// secret is passed as the second argument to passwordHasher when hashing passwords.
secret []byte
// authExpireDur is added to the current time to compute the JWT "exp" claim.
authExpireDur time.Duration
}
// init registers *structs.User with encoding/gob. checkSession reads a
// *structs.User back out of the session, so this is presumably required by the
// session storage encoding — NOTE(review): confirm the session store uses gob.
func init() {
gob.Register(&structs.User{})
}
// Compile-time check that *CoreService satisfies core.CoreService.
var _ core.CoreService = (*CoreService)(nil)
// CoreV1ServiceParams holds the dependencies required by NewCoreService.
type CoreV1ServiceParams struct {
// Logger is the base logger; NewCoreService derives a sugared logger named "core-v1" from it.
Logger *zap.Logger
// DB is the gorm handle the service queries.
DB *gorm.DB
// Secret is stored on the service and used when hashing passwords.
Secret []byte
}
// NewCoreServiceFromAppContext builds a CoreService from the shared application
// context. appCtx.Data must contain:
//
//   - "db":     a *gorm.DB
//   - "secret": a string, used as the password-hashing secret
//
// A missing or wrongly-typed entry is a fatal configuration error reported via
// the logger's Fatalf (which terminates the process under zap's default fatal
// hook). The nil returns after Fatalf only satisfy the compiler / a
// non-exiting fatal hook.
func NewCoreServiceFromAppContext(appCtx *internal.AppContext) *CoreService {
	logger := appCtx.Logger.Sugar()

	rawDB, ok := appCtx.Data["db"]
	if !ok {
		// Fatalf (not Fatal): Fatal does not interpret format verbs.
		logger.Fatalf("core v1 service requires `%s` on application context", "db")
		return nil
	}
	db, ok := rawDB.(*gorm.DB)
	if !ok {
		logger.Fatalf("core v1 service requires `db` on application context to be *gorm.DB, got %T", rawDB)
		return nil
	}

	rawSecret, ok := appCtx.Data["secret"]
	if !ok {
		logger.Fatalf("core v1 service requires `%s` on application context", "secret")
		return nil
	}
	secret, ok := rawSecret.(string)
	if !ok {
		logger.Fatalf("core v1 service requires `secret` on application context to be string, got %T", rawSecret)
		return nil
	}

	return NewCoreService(CoreV1ServiceParams{
		Logger: appCtx.Logger,
		DB:     db,
		Secret: []byte(secret),
	})
}
// NewCoreService constructs a CoreService from explicit dependencies.
//
// The service gets a validator created with vdext.New("core_v1_dto"), a
// sugared logger named "core-v1", and a 5 hour JWT lifetime (authExpireDur).
// migrateSchemeAndData is called before the service is returned.
func NewCoreService(params CoreV1ServiceParams) *CoreService {
	c := &CoreService{
		vd:            vdext.New("core_v1_dto"),
		logger:        params.Logger.Sugar().Named("core-v1"),
		db:            params.DB,
		secret:        params.Secret,
		authExpireDur: 5 * time.Hour,
	}
	// NOTE(review): defined elsewhere; presumably migrates the DB schema and
	// seeds data — confirm.
	c.migrateSchemeAndData()
	return c
}
// getEd25519KeyPair generates a fresh Ed25519 key pair using crypto/rand.
// Note the result order is (private, public), the reverse of
// ed25519.GenerateKey's (public, private).
func (c *CoreService) getEd25519KeyPair() (pri ed25519.PrivateKey, pub ed25519.PublicKey, err error) {
	publicKey, privateKey, genErr := ed25519.GenerateKey(rand.Reader)
	return privateKey, publicKey, genErr
}
// generateJwtTokenForUser signs an EdDSA JWT for user using the user's own
// Ed25519 private key (user.PrivateKey).
//
// Claims:
//   - exp:      now + c.authExpireDur, as a NumericDate (Unix seconds)
//   - iss:      "core-v1"
//   - uid:      user.ID
//   - jti:      the caller-supplied token ID
//   - username: user.Username
//   - roles:    IDs of user.Roles (encoded as a JSON array, empty if none)
//
// An error is returned when signing fails, e.g. when user.PrivateKey is not a
// valid Ed25519 private key.
func (c *CoreService) generateJwtTokenForUser(user *structs.User, jti string) (tokenStr string, err error) {
	// Non-nil so that a user without roles encodes as [] rather than null.
	roleIDs := make([]int64, 0, len(user.Roles))
	for _, role := range user.Roles {
		roleIDs = append(roleIDs, role.ID)
	}

	token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jwt.MapClaims{
		// RFC 7519 requires "exp" to be a NumericDate (seconds since epoch).
		// A time.Time would be JSON-encoded as an RFC 3339 string, which
		// jwt v4's MapClaims validation does not accept as an expiry, so the
		// issued token would always fail verification.
		"exp":      time.Now().Add(c.authExpireDur).Unix(),
		"iss":      "core-v1",
		"uid":      user.ID,
		"jti":      jti,
		"username": user.Username,
		"roles":    roleIDs,
	})
	return token.SignedString(ed25519.PrivateKey(user.PrivateKey))
}
// stripConfidentialUserData clears u's password and key pair in place,
// presumably before u is handed back to a client.
func (c *CoreService) stripConfidentialUserData(u *structs.User) {
	u.PrivateKey, u.PublicKey = nil, nil
	u.Password = ""
}
// getSession loads the fiber session for the request carried by ctx.
//
// On failure the error is logged and a CoreServicesException is returned with
// Code 500, Message "unable to get session" and the underlying error text in
// Parameters["inner"]; the returned session is nil in that case.
func (c *CoreService) getSession(ctx context.Context) (sess *session.Session, err error) {
	fctx := http.GetFiberFromContext(ctx)
	sess, err = http.Session.Get(fctx)
	if err != nil {
		c.logger.Error("get session failed: ", err)
		exc := exceptions.NewCoreServicesException()
		exc.Code = 500
		exc.Message = "unable to get session"
		exc.Parameters = map[string]string{
			"inner": err.Error(),
		}
		return nil, exc
	}
	// TODO: support authentication through the "X-Hub-Auth-Token" request
	// header. The previous placeholder only peeked at the header inside an
	// empty branch, so it had no effect.
	return sess, nil
}
// Sentinel errors used by this service. wrapServiceError translates each one
// into a CoreServicesException: ErrBadRequest -> 400, ErrForbidden -> 403,
// ErrNotAuthenticated -> 401, ErrInvalidState -> 500 (with a "redirect"
// parameter of "/").
var (
ErrBadRequest = errors.New("bad request")
ErrForbidden = errors.New("forbidden")
ErrNotAuthenticated = errors.New("not authenticated")
ErrInvalidState = errors.New("invalid state")
)
// checkSession returns the user stored under "user" in sess after reloading
// that user's row from the database.
//
// Errors (the returned user is always nil when err != nil):
//   - ErrNotAuthenticated if the session holds no (non-nil) *structs.User.
//   - ErrNotAuthenticated if the user no longer exists in the database; the
//     session is destroyed first.
//   - any other database error, after destroying the session.
//   - the error from sess.Destroy, if revoking the session itself fails.
func (c *CoreService) checkSession(sess *session.Session) (user *structs.User, err error) {
	sessUser, ok := sess.Get("user").(*structs.User)
	if !ok || sessUser == nil {
		return nil, ErrNotAuthenticated
	}

	// The user must still be valid and readable; otherwise the session can no
	// longer be trusted and is revoked.
	// NOTE(review): no Preload is used, so associations such as Roles
	// presumably keep the values stored in the session — confirm this is
	// intended, since isSystemUser relies on Roles.
	if dbErr := c.db.Where("id = ?", sessUser.ID).First(sessUser).Error; dbErr != nil {
		if destroyErr := sess.Destroy(); destroyErr != nil {
			return nil, destroyErr
		}
		if errors.Is(dbErr, gorm.ErrRecordNotFound) {
			// A deleted account is an authentication failure (401), not a
			// missing resource (404, which is what wrapServiceError would
			// produce for gorm.ErrRecordNotFound).
			return nil, ErrNotAuthenticated
		}
		return nil, dbErr
	}
	// TODO: also use the JWT ?
	return sessUser, nil
}
// getAndCheckSession resolves the session for ctx via getSession and validates
// it with checkSession, returning the authenticated user.
func (c *CoreService) getAndCheckSession(ctx context.Context) (user *structs.User, err error) {
	sess, sessErr := c.getSession(ctx)
	if sessErr != nil {
		return nil, sessErr
	}
	return c.checkSession(sess)
}
// getAndCheckSessionAdmin is getAndCheckSession restricted to users that hold
// at least one role of type structs.RoleType_SYSTEM. Other authenticated users
// get ErrForbidden; note the user value is still returned alongside it.
func (c *CoreService) getAndCheckSessionAdmin(ctx context.Context) (user *structs.User, err error) {
	user, err = c.getAndCheckSession(ctx)
	switch {
	case err != nil:
		return user, err
	case !c.isSystemUser(user):
		return user, ErrForbidden
	default:
		return user, nil
	}
}
// isSystemUser reports whether any of user's roles has RoleType
// structs.RoleType_SYSTEM.
func (c *CoreService) isSystemUser(user *structs.User) bool {
	found := false
	for i := 0; i < len(user.Roles) && !found; i++ {
		found = user.Roles[i].RoleType == structs.RoleType_SYSTEM
	}
	return found
}
// alert2Str JSON-encodes alert. If encoding fails it falls back to "{}" so the
// caller always receives valid JSON.
func (c *CoreService) alert2Str(alert *structs.AlertInfo) (s string) {
	encoded, err := json.Marshal(alert)
	if err == nil {
		return string(encoded)
	}
	return "{}"
}
// wrapServiceError converts err into a *exceptions.CoreServicesException for
// returning from service methods.
//
//   - A nil err yields nil.
//   - An error that is (or wraps) a *exceptions.CoreServicesException is
//     returned as that exception, unchanged.
//   - Otherwise a new exception is built with Message = err.Error(), and a Code
//     plus an "alert" parameter (a JSON-encoded structs.AlertInfo) chosen by
//     error kind. Kinds are matched with errors.Is, so wrapped sentinels are
//     recognized:
//     gorm.ErrRecordNotFound -> 404, ErrForbidden -> 403, ErrBadRequest -> 400,
//     ErrNotAuthenticated -> 401, ErrInvalidState -> 500 (plus "redirect": "/"),
//     anything else -> 500.
//
// customMessage only affects the not-found alert: one value replaces the
// description; two or more values set the title and the description.
func (c *CoreService) wrapServiceError(err error, customMessage ...string) error {
	if err == nil {
		// Previously err.Error() below would panic on nil.
		return nil
	}
	var existing *exceptions.CoreServicesException
	if errors.As(err, &existing) {
		return existing
	}

	exc := exceptions.NewCoreServicesException()
	exc.Message = err.Error()
	exc.Code = 500
	params := map[string]string{}
	// Most kinds show the raw error text as the alert description.
	title, description := "Server Error", err.Error()

	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		exc.Code = 404
		title, description = "Resource", "Resource not found"
		switch {
		case len(customMessage) >= 2:
			title, description = customMessage[0], customMessage[1]
		case len(customMessage) == 1:
			description = customMessage[0]
		}
	case errors.Is(err, ErrForbidden):
		exc.Code = 403
		title = "Access Error"
	case errors.Is(err, ErrBadRequest):
		exc.Code = 400
		title = "Input Error"
	case errors.Is(err, ErrNotAuthenticated):
		exc.Code = 401
		title = "Auth Error"
	case errors.Is(err, ErrInvalidState):
		// Code and title stay at the 500 / "Server Error" defaults; the client
		// is additionally told to go back to "/".
		params["redirect"] = "/"
	}

	params["alert"] = c.alert2Str(&structs.AlertInfo{
		Title:       title,
		Description: description,
	})
	exc.Parameters = params
	return exc
}
// checkRequest validates v with the service validator, passing ctx through to
// StructCtx. On failure it returns a CoreServicesException with Code 400,
// Message set to the raw validator error, and an "alert" parameter whose
// description is the summary produced by err2str.
func (c *CoreService) checkRequest(ctx context.Context, v any) (err error) {
	vErr := c.vd.StructCtx(ctx, v)
	if vErr == nil {
		return nil
	}
	alert := c.alert2Str(&structs.AlertInfo{
		Title:       "Validation Error",
		Description: c.err2str(vErr),
	})
	exc := exceptions.NewCoreServicesException()
	exc.Code = 400
	exc.Message = vErr.Error()
	exc.Parameters = map[string]string{"alert": alert}
	return exc
}
// passwordHasher hashes plain with the package-level passwordHasher, supplying
// this service's secret as its second argument.
func (c *CoreService) passwordHasher(plain string) string {
	key := c.secret
	return passwordHasher(plain, key)
}
// err2str renders err for display. A validator.ValidationErrors (found via
// errors.As) becomes a ", "-separated list of "<field> <tag>" entries, e.g.
// "Email required, Age gte"; any other error is rendered with Error().
func (c *CoreService) err2str(err error) (s string) {
	var fieldErrs validator.ValidationErrors
	if !errors.As(err, &fieldErrs) {
		return err.Error()
	}
	parts := make([]string, 0, len(fieldErrs))
	for _, fe := range fieldErrs {
		parts = append(parts, fe.Field()+" "+fe.Tag())
	}
	return strings.Join(parts, ", ")
}
// pagination2Offset normalizes pagination in place and returns the matching
// SQL offset and limit.
//
// A nil pagination yields offset 0 and limit 25. Otherwise a non-positive Page
// becomes 1, RowsPerPage above 100 is clamped to 100 and a non-positive
// RowsPerPage becomes 25; then limit = RowsPerPage and
// offset = (Page-1) * limit.
func (c *CoreService) pagination2Offset(pagination *structs.Pagination) (offset int, limit int) {
	const defaultLimit, maxLimit = 25, 100
	if pagination == nil {
		return 0, defaultLimit
	}
	if pagination.Page <= 0 {
		pagination.Page = 1
	}
	if pagination.RowsPerPage > maxLimit {
		pagination.RowsPerPage = maxLimit
	} else if pagination.RowsPerPage <= 0 {
		pagination.RowsPerPage = defaultLimit
	}
	limit = int(pagination.RowsPerPage)
	return int(pagination.Page-1) * limit, limit
}
// ScPaginate returns a gorm scope that applies LIMIT/OFFSET derived from
// pagination by pagination2Offset. The normalization (and therefore the
// in-place mutation of pagination) happens each time the scope is applied,
// not when ScPaginate is called.
func (c *CoreService) ScPaginate(pagination *structs.Pagination) scopeFunc {
	paginate := func(db *gorm.DB) *gorm.DB {
		offset, limit := c.pagination2Offset(pagination)
		// TODO: switch to keyset pagination ("id > ?") to avoid large OFFSETs.
		return db.Limit(limit).Offset(offset)
	}
	return paginate
}
// likePatternEscaper escapes the LIKE metacharacters '%' and '_' (and the
// escape character '!' itself) so a term matches literally under ESCAPE '!'.
var likePatternEscaper = strings.NewReplacer("!", "!!", "%", "!%", "_", "!_")

// ScSearchTerm returns a gorm scope keeping rows where at least one of fields
// contains searchTerm as a substring. It is a simple LIKE-based substitute for
// full-text search:
//
//	(f1 LIKE ? ESCAPE '!' OR f2 LIKE ? ESCAPE '!' ...)
//
// with "%<escaped term>%" bound to every placeholder. '%' and '_' in the term
// are matched literally instead of acting as wildcards. The scope is a no-op
// when fields is empty or searchTerm is empty; an empty term would otherwise
// match '%%', which silently drops rows whose searched fields are all NULL.
//
// SECURITY: field names are interpolated into the SQL verbatim. They must come
// from trusted code, never from request input.
func (c *CoreService) ScSearchTerm(searchTerm string, fields ...string) scopeFunc {
	if len(fields) == 0 || searchTerm == "" {
		return func(db *gorm.DB) *gorm.DB { return db }
	}

	pattern := fmt.Sprintf("%%%s%%", likePatternEscaper.Replace(searchTerm))
	conds := make([]string, len(fields))
	values := make([]any, len(fields))
	for i, field := range fields {
		conds[i] = fmt.Sprintf("%s LIKE ? ESCAPE '!'", field)
		values[i] = pattern
	}
	// Parenthesize explicitly so the OR chain cannot bind with other
	// conditions on the same query.
	where := "(" + strings.Join(conds, " OR ") + ")"

	return func(db *gorm.DB) *gorm.DB {
		return db.Where(where, values...)
	}
}