271 lines
6.5 KiB
Go
271 lines
6.5 KiB
Go
package dao
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"go-common/app/service/main/antispam/util"
|
|
|
|
"go-common/library/database/sql"
|
|
"go-common/library/log"
|
|
)
|
|
|
|
// SQL text used by the rule DAO. The SELECT templates contain %s verbs that
// callers fill in with fmt.Sprintf before the statement is executed; only the
// INSERT and UPDATE statements use bound `?` parameters.
//
// NOTE(review): values substituted through %s are not escaped by anything in
// this file — confirm that every caller passes trusted input.
const (
	// columnRules is the column list shared by every rule SELECT. Its order
	// matches the order in which mapRowToRules scans a row.
	columnRules = "id, area, limit_type, limit_scope, dur_sec, allowed_counts, ctime, mtime"

	// selectRuleCountsSQL counts rules; %s receives the optional WHERE clause
	// built by GetByCond.
	selectRuleCountsSQL = `SELECT COUNT(1) FROM rate_limit_rules %s`

	// selectRulesByCondSQL lists rules; %s receives the WHERE / ORDER BY /
	// LIMIT text built by GetByCond.
	selectRulesByCondSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules %s`

	// selectRuleByIDsSQL lists rules by primary key; %s receives the id list
	// rendered by util.IntSliToSQLVarchars in GetByIDs.
	selectRuleByIDsSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE id IN(%s)`

	// selectRulesByAreaSQL lists the rules of one area.
	selectRulesByAreaSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s`

	// selectRulesByAreaAndTypeSQL lists the rules of one area and limit type.
	selectRulesByAreaAndTypeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s`

	// selectRulesByAreaAndTypeAndScopeSQL selects by area, limit type and
	// limit scope, the same triple that updateRuleSQL uses as its key.
	selectRulesByAreaAndTypeAndScopeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s AND limit_scope = %s`

	// insertRuleSQL creates a rule. Parameters, in order: area, limit_type,
	// limit_scope, dur_sec, allowed_counts.
	insertRuleSQL = `INSERT INTO rate_limit_rules(area, limit_type, limit_scope, dur_sec, allowed_counts) VALUES(?, ?, ?, ?, ?)`

	// updateRuleSQL changes dur_sec, allowed_counts and mtime of the rule
	// identified by (area, limit_type, limit_scope). Parameters, in order:
	// dur_sec, allowed_counts, mtime, area, limit_type, limit_scope.
	updateRuleSQL = `UPDATE rate_limit_rules SET dur_sec = ?, allowed_counts = ?, mtime = ? WHERE area = ? AND limit_type = ? AND limit_scope = ?`
)
|
|
|
|
// Rule is one row of the rate_limit_rules table. Each field carries a `db`
// tag naming its column; mapRowToRules fills the fields positionally in the
// order of columnRules.
type Rule struct {
	// ID is the primary key; insertRule sets it from LastInsertId.
	ID int64 `db:"id"`

	// Area is the value of the area column.
	Area int `db:"area"`

	// LimitType is the value of the limit_type column.
	// NOTE(review): presumably one of the LimitType* constants — confirm.
	LimitType int `db:"limit_type"`

	// LimitScope is the value of the limit_scope column.
	// NOTE(review): presumably one of the LimitScope* constants — confirm.
	LimitScope int `db:"limit_scope"`

	// DurationSec is the value of the dur_sec column.
	// NOTE(review): presumably the length of the limiting window in seconds.
	DurationSec int64 `db:"dur_sec"`

	// AllowedCounts is the value of the allowed_counts column.
	// NOTE(review): presumably the number of actions permitted per window.
	AllowedCounts int64 `db:"allowed_counts"`

	// CTime is read from the ctime column; nothing in this file writes it.
	CTime time.Time `db:"ctime"`

	// MTime is read from the mtime column; updateRule sets it to time.Now().
	MTime time.Time `db:"mtime"`
}
|
|
|
|
// RuleDaoImpl is the data-access object for rate_limit_rules. It has no
// fields: its methods run their queries through the package-level db handle.
type RuleDaoImpl struct{}
|
|
|
|
// Limit type values, numbered from 0 by iota. All four constants have type int.
// NOTE(review): presumably these are the values stored in the limit_type
// column — confirm before renumbering or reordering.
const (
	// LimitTypeDefaultLimit is limit type 0.
	LimitTypeDefaultLimit int = iota
	// LimitTypeRestrictLimit is limit type 1.
	LimitTypeRestrictLimit
	// LimitTypeWhite is limit type 2.
	LimitTypeWhite
	// LimitTypeBlack is limit type 3.
	LimitTypeBlack
)
|
|
|
|
// Limit scope values, numbered from 0 by iota. Both constants have type int.
// NOTE(review): presumably these are the values stored in the limit_scope
// column — confirm before renumbering or reordering.
const (
	// LimitScopeGlobal is limit scope 0.
	LimitScopeGlobal int = iota
	// LimitScopeLocal is limit scope 1.
	LimitScopeLocal
)
|
|
|
|
// NewRuleDao .
|
|
func NewRuleDao() *RuleDaoImpl {
|
|
return &RuleDaoImpl{}
|
|
}
|
|
|
|
func updateRule(ctx context.Context, executer Executer, r *Rule) error {
|
|
_, err := executer.Exec(ctx,
|
|
updateRuleSQL,
|
|
|
|
r.DurationSec,
|
|
r.AllowedCounts,
|
|
time.Now(),
|
|
|
|
r.Area,
|
|
r.LimitType,
|
|
r.LimitScope,
|
|
)
|
|
if err != nil {
|
|
log.Error("%v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func insertRule(ctx context.Context, executer Executer, r *Rule) error {
|
|
res, err := executer.Exec(ctx,
|
|
insertRuleSQL,
|
|
|
|
r.Area,
|
|
r.LimitType,
|
|
r.LimitScope,
|
|
r.DurationSec,
|
|
r.AllowedCounts,
|
|
)
|
|
if err != nil {
|
|
log.Error("%v", err)
|
|
return err
|
|
}
|
|
lastID, err := res.LastInsertId()
|
|
if err != nil {
|
|
log.Error("%v", err)
|
|
return err
|
|
}
|
|
r.ID = lastID
|
|
return nil
|
|
}
|
|
|
|
// GetByCond .
|
|
func (*RuleDaoImpl) GetByCond(ctx context.Context, cond *Condition) (rules []*Rule, totalCounts int64, err error) {
|
|
sqlConds := make([]string, 0)
|
|
if cond.Area != "" {
|
|
sqlConds = append(sqlConds, fmt.Sprintf("area = %s", cond.Area))
|
|
}
|
|
if cond.State != "" {
|
|
sqlConds = append(sqlConds, fmt.Sprintf("state = %s", cond.State))
|
|
}
|
|
var optionSQL string
|
|
if len(sqlConds) > 0 {
|
|
optionSQL = fmt.Sprintf("WHERE %s", strings.Join(sqlConds, " AND "))
|
|
}
|
|
|
|
var limitSQL string
|
|
if cond.Pagination != nil {
|
|
queryCountsSQL := fmt.Sprintf(selectRuleCountsSQL, optionSQL)
|
|
totalCounts, err = GetTotalCounts(ctx, db, queryCountsSQL)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
offset, limit := cond.OffsetLimit(totalCounts)
|
|
if limit == 0 {
|
|
return nil, 0, ErrResourceNotExist
|
|
}
|
|
limitSQL = fmt.Sprintf("LIMIT %d, %d", offset, limit)
|
|
}
|
|
if cond.OrderBy != "" {
|
|
optionSQL = fmt.Sprintf("%s ORDER BY %s %s", optionSQL, cond.OrderBy, cond.Order)
|
|
}
|
|
if limitSQL != "" {
|
|
optionSQL = fmt.Sprintf("%s %s", optionSQL, limitSQL)
|
|
}
|
|
querySQL := fmt.Sprintf(selectRulesByCondSQL, optionSQL)
|
|
log.Info("OptionSQL(%s), GetByCondSQL(%s)", optionSQL, querySQL)
|
|
rules, err = queryRules(ctx, db, querySQL)
|
|
if err != nil {
|
|
return nil, totalCounts, err
|
|
}
|
|
return rules, totalCounts, nil
|
|
}
|
|
|
|
// Update .
|
|
func (rdi *RuleDaoImpl) Update(ctx context.Context, r *Rule) (*Rule, error) {
|
|
if err := updateRule(ctx, db, r); err != nil {
|
|
return nil, err
|
|
}
|
|
return rdi.GetByAreaAndTypeAndScope(ctx, &Condition{
|
|
Area: fmt.Sprintf("%d", r.Area),
|
|
LimitType: fmt.Sprintf("%d", r.LimitType),
|
|
LimitScope: fmt.Sprintf("%d", r.LimitScope),
|
|
})
|
|
}
|
|
|
|
// Insert .
|
|
func (rdi *RuleDaoImpl) Insert(ctx context.Context, r *Rule) (*Rule, error) {
|
|
if err := insertRule(ctx, db, r); err != nil {
|
|
return nil, err
|
|
}
|
|
return rdi.GetByID(ctx, r.ID)
|
|
}
|
|
|
|
// GetByID .
|
|
func (rdi *RuleDaoImpl) GetByID(ctx context.Context, id int64) (*Rule, error) {
|
|
rs, err := rdi.GetByIDs(ctx, []int64{id})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rs[0] == nil {
|
|
return nil, ErrResourceNotExist
|
|
}
|
|
return rs[0], nil
|
|
}
|
|
|
|
// GetByIDs .
|
|
func (*RuleDaoImpl) GetByIDs(ctx context.Context, ids []int64) ([]*Rule, error) {
|
|
rs, err := queryRules(ctx, db, fmt.Sprintf(selectRuleByIDsSQL, util.IntSliToSQLVarchars(ids)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := make([]*Rule, len(ids))
|
|
for i, id := range ids {
|
|
for _, r := range rs {
|
|
if r.ID == id {
|
|
res[i] = r
|
|
}
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// GetByAreaAndLimitType .
|
|
func (*RuleDaoImpl) GetByAreaAndLimitType(ctx context.Context, cond *Condition) ([]*Rule, error) {
|
|
return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeSQL, cond.Area, cond.LimitType))
|
|
}
|
|
|
|
// GetByAreaAndTypeAndScope .
|
|
func (*RuleDaoImpl) GetByAreaAndTypeAndScope(ctx context.Context, cond *Condition) (*Rule, error) {
|
|
rs, err := queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeAndScopeSQL,
|
|
cond.Area,
|
|
cond.LimitType,
|
|
cond.LimitScope,
|
|
))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return rs[0], nil
|
|
}
|
|
|
|
// GetByArea .
|
|
func (*RuleDaoImpl) GetByArea(ctx context.Context, cond *Condition) ([]*Rule, error) {
|
|
return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaSQL, cond.Area))
|
|
}
|
|
|
|
func queryRules(ctx context.Context, q Querier, rawSQL string) ([]*Rule, error) {
|
|
log.Info("Query sql: %q", rawSQL)
|
|
rows, err := q.Query(ctx, rawSQL)
|
|
if err == sql.ErrNoRows {
|
|
err = ErrResourceNotExist
|
|
}
|
|
if err != nil {
|
|
log.Error("Error: %v, RawSQL: %s", err, rawSQL)
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
rs, err := mapRowToRules(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, ErrResourceNotExist
|
|
}
|
|
return rs, nil
|
|
}
|
|
|
|
func mapRowToRules(rows *sql.Rows) (rs []*Rule, err error) {
|
|
for rows.Next() {
|
|
r := Rule{}
|
|
err = rows.Scan(
|
|
&r.ID,
|
|
&r.Area,
|
|
&r.LimitType,
|
|
&r.LimitScope,
|
|
&r.DurationSec,
|
|
&r.AllowedCounts,
|
|
&r.CTime,
|
|
&r.MTime,
|
|
)
|
|
if err != nil {
|
|
log.Error("%v", err)
|
|
return nil, err
|
|
}
|
|
rs = append(rs, &r)
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
log.Error("%v", err)
|
|
return nil, err
|
|
}
|
|
return rs, nil
|
|
}
|