package coresvc import ( "context" "crypto/ed25519" "crypto/rand" "encoding/gob" "encoding/json" "errors" "fmt" "strings" "time" "wpw-common/internal" "wpw-common/internal/http" "wpw-common/internal/vdext" "wpw-common/pkg/gen/biz/core" "wpw-common/pkg/gen/biz/core/exceptions" "wpw-common/pkg/gen/biz/core/structs" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2/middleware/session" "github.com/golang-jwt/jwt/v4" "go.uber.org/zap" "gorm.io/gorm" ) type scopeFunc = func(db *gorm.DB) *gorm.DB type CoreService struct { vd *validator.Validate logger *zap.SugaredLogger db *gorm.DB secret []byte authExpireDur time.Duration } func init() { gob.Register(&structs.User{}) } var _ core.CoreService = (*CoreService)(nil) type CoreV1ServiceParams struct { Logger *zap.Logger DB *gorm.DB Secret []byte } func NewCoreServiceFromAppContext(appCtx *internal.AppContext) *CoreService { var ( argName string db, secret any exist bool ) // db db, exist = appCtx.Data["db"] if !exist { argName = "db" goto invalid_arguments } // secret secret, exist = appCtx.Data["secret"] if !exist { argName = "secret" goto invalid_arguments } return NewCoreService(CoreV1ServiceParams{ Logger: appCtx.Logger, DB: db.(*gorm.DB), Secret: []byte(secret.(string)), }) invalid_arguments: appCtx.Logger.Sugar().Fatal("core v1 service requires `%s` on application context", argName) return nil } func NewCoreService(params CoreV1ServiceParams) *CoreService { c := &CoreService{ vd: vdext.New("core_v1_dto"), logger: params.Logger.Sugar().Named("core-v1"), db: params.DB, secret: params.Secret, authExpireDur: time.Hour * 5, } token := jwt.New(jwt.SigningMethodEdDSA) claims := token.Claims.(jwt.MapClaims) _ = claims c.migrateSchemeAndData() return c } func (c *CoreService) getEd25519KeyPair() (pri ed25519.PrivateKey, pub ed25519.PublicKey, err error) { pub, pri, err = ed25519.GenerateKey(rand.Reader) return } func (c *CoreService) generateJwtTokenForUser(user *structs.User, jti string) (tokenStr string, err error) { // user.PrivateKey token := jwt.New(jwt.SigningMethodEdDSA) claims := token.Claims.(jwt.MapClaims) claims["exp"] = time.Now().Add(c.authExpireDur) claims["iss"] = "core-v1" claims["uid"] = user.ID claims["jti"] = jti claims["username"] = user.Username var roleIds []int64 for _, role := range user.Roles { roleIds = append(roleIds, role.ID) } claims["roles"] = roleIds tokenStr, err = token.SignedString(ed25519.PrivateKey(user.PrivateKey)) return } func (c *CoreService) stripConfidentialUserData(u *structs.User) { u.Password = "" u.PrivateKey = nil u.PublicKey = nil } func (c *CoreService) getSession(ctx context.Context) (sess *session.Session, err error) { fctx := http.GetFiberFromContext(ctx) sess, err = http.Session.Get(fctx) if err != nil { c.logger.Error("get session failed: ", err) exc := exceptions.NewCoreServicesException() exc.Code = 500 exc.Message = "unable to get session" exc.Parameters = map[string]string{ "inner": err.Error(), } err = exc return } if authToken := fctx.Request().Header.Peek("X-Hub-Auth-Token"); authToken != nil { // tbd } return } var ( ErrBadRequest = errors.New("bad request") ErrForbidden = errors.New("forbidden") ErrNotAuthenticated = errors.New("not authenticated") ErrInvalidState = errors.New("invalid state") ) func (c *CoreService) checkSession(sess *session.Session) (user *structs.User, err error) { userVal := sess.Get("user") if userVal == nil { err = ErrNotAuthenticated return } var ok bool user, ok = userVal.(*structs.User) if !ok { err = ErrNotAuthenticated return } tx := c.db.Where("id = ?", user.ID).First(user) if err = tx.Error; err != nil { // must check if the user is still valid / no db error // or we just revoke the session. if err = sess.Destroy(); err != nil { return } return } // TODO: also use the JWT ? return } func (c *CoreService) getAndCheckSession(ctx context.Context) (user *structs.User, err error) { var sess *session.Session if sess, err = c.getSession(ctx); err != nil { return } if user, err = c.checkSession(sess); err != nil { return } return } func (c *CoreService) getAndCheckSessionAdmin(ctx context.Context) (user *structs.User, err error) { if user, err = c.getAndCheckSession(ctx); err != nil { return } // user with type system if !c.isSystemUser(user) { err = ErrForbidden return } return } func (c *CoreService) isSystemUser(user *structs.User) bool { for _, role := range user.Roles { if role.RoleType == structs.RoleType_SYSTEM { return true } } return false } func (c *CoreService) alert2Str(alert *structs.AlertInfo) (s string) { var b []byte var err error if b, err = json.Marshal(alert); err != nil { s = "{}" return } s = string(b) return } func (c *CoreService) wrapServiceError(err error, customMessage ...string) error { if errVal, ok := err.(*exceptions.CoreServicesException); ok { return errVal } exc := exceptions.NewCoreServicesException() exc.Message = err.Error() switch err { case gorm.ErrRecordNotFound: exc.Code = 404 alert := &structs.AlertInfo{ Title: "Resource", Description: "Resource not found", } if len(customMessage) >= 2 { alert.Title = customMessage[0] alert.Description = customMessage[1] } else if len(customMessage) >= 1 { alert.Description = customMessage[0] } exc.Parameters = map[string]string{ "alert": c.alert2Str(alert), } case ErrForbidden: exc.Code = 403 exc.Parameters = map[string]string{ "alert": c.alert2Str(&structs.AlertInfo{ Title: "Access Error", Description: err.Error(), }), } case ErrBadRequest: exc.Code = 400 exc.Parameters = map[string]string{ "alert": c.alert2Str(&structs.AlertInfo{ Title: "Input Error", Description: err.Error(), }), } case ErrNotAuthenticated: exc.Code = 401 exc.Parameters = map[string]string{ "alert": c.alert2Str(&structs.AlertInfo{ Title: "Auth Error", Description: err.Error(), }), } case ErrInvalidState: exc.Code = 500 exc.Parameters = map[string]string{ "redirect": "/", "alert": c.alert2Str(&structs.AlertInfo{ Title: "Server Error", Description: err.Error(), }), } default: exc.Code = 500 exc.Parameters = map[string]string{ "alert": c.alert2Str(&structs.AlertInfo{ Title: "Server Error", Description: err.Error(), }), } } return exc } func (c *CoreService) checkRequest(ctx context.Context, v any) (err error) { err = c.vd.StructCtx(ctx, v) if err != nil { exc := exceptions.NewCoreServicesException() exc.Code = 400 exc.Message = err.Error() exc.Parameters = map[string]string{ "alert": c.alert2Str(&structs.AlertInfo{ Title: "Validation Error", Description: c.err2str(err), }), } err = exc return } return } func (c *CoreService) passwordHasher(plain string) string { return passwordHasher(plain, c.secret) } func (c *CoreService) err2str(err error) (s string) { var ve validator.ValidationErrors if errors.As(err, &ve) { var sb strings.Builder for i, f := range ve { sb.WriteString(f.Field()) sb.Write([]byte(` `)) sb.WriteString(f.Tag()) if i+1 < len(ve) { sb.Write([]byte(`, `)) } } return sb.String() } return err.Error() } // pagination2Offset correct pagination and return the offset func (c *CoreService) pagination2Offset(pagination *structs.Pagination) (offset int, limit int) { offset, limit = 0, 25 if pagination == nil { return } if pagination.Page <= 0 { pagination.Page = 1 } switch { case pagination.RowsPerPage > 100: pagination.RowsPerPage = 100 case pagination.RowsPerPage <= 0: pagination.RowsPerPage = int64(limit) } limit = int(pagination.RowsPerPage) offset = int(pagination.Page-1) * limit return } // ScPaginate gorm scope function func (c *CoreService) ScPaginate(pagination *structs.Pagination) scopeFunc { return func(db *gorm.DB) *gorm.DB { offset, limit := c.pagination2Offset(pagination) // use id > ? next time. return db.Limit(limit).Offset(offset) } } // ScSearchTerm gorm scope function func (c *CoreService) ScSearchTerm(searchTerm string, fields ...string) scopeFunc { // bloat FTS-like search :( var sb strings.Builder var values []any if len(fields) <= 0 { goto ret } for i, field := range fields { fmt.Fprintf(&sb, "%s LIKE ?", field) if i < len(fields)-1 { sb.WriteString(" OR ") } values = append(values, fmt.Sprintf("%%%s%%", searchTerm)) } ret: return func(db *gorm.DB) *gorm.DB { return db.Where(sb.String(), values...) } }