package dao import ( "context" "fmt" "strings" "time" "go-common/app/service/main/antispam/util" "go-common/library/database/sql" "go-common/library/log" ) const ( columnRules = "id, area, limit_type, limit_scope, dur_sec, allowed_counts, ctime, mtime" selectRuleCountsSQL = `SELECT COUNT(1) FROM rate_limit_rules %s` selectRulesByCondSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules %s` selectRuleByIDsSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE id IN(%s)` selectRulesByAreaSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s` selectRulesByAreaAndTypeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s` selectRulesByAreaAndTypeAndScopeSQL = `SELECT ` + columnRules + ` FROM rate_limit_rules WHERE area = %s AND limit_type = %s AND limit_scope = %s` insertRuleSQL = `INSERT INTO rate_limit_rules(area, limit_type, limit_scope, dur_sec, allowed_counts) VALUES(?, ?, ?, ?, ?)` updateRuleSQL = `UPDATE rate_limit_rules SET dur_sec = ?, allowed_counts = ?, mtime = ? WHERE area = ? AND limit_type = ? AND limit_scope = ?` ) // Rule . type Rule struct { ID int64 `db:"id"` Area int `db:"area"` LimitType int `db:"limit_type"` LimitScope int `db:"limit_scope"` DurationSec int64 `db:"dur_sec"` AllowedCounts int64 `db:"allowed_counts"` CTime time.Time `db:"ctime"` MTime time.Time `db:"mtime"` } // RuleDaoImpl . type RuleDaoImpl struct{} const ( // LimitTypeDefaultLimit . LimitTypeDefaultLimit int = iota // LimitTypeRestrictLimit . LimitTypeRestrictLimit // LimitTypeWhite . LimitTypeWhite // LimitTypeBlack . LimitTypeBlack ) const ( // LimitScopeGlobal . LimitScopeGlobal int = iota // LimitScopeLocal . LimitScopeLocal ) // NewRuleDao . func NewRuleDao() *RuleDaoImpl { return &RuleDaoImpl{} } func updateRule(ctx context.Context, executer Executer, r *Rule) error { _, err := executer.Exec(ctx, updateRuleSQL, r.DurationSec, r.AllowedCounts, time.Now(), r.Area, r.LimitType, r.LimitScope, ) if err != nil { log.Error("%v", err) return err } return nil } func insertRule(ctx context.Context, executer Executer, r *Rule) error { res, err := executer.Exec(ctx, insertRuleSQL, r.Area, r.LimitType, r.LimitScope, r.DurationSec, r.AllowedCounts, ) if err != nil { log.Error("%v", err) return err } lastID, err := res.LastInsertId() if err != nil { log.Error("%v", err) return err } r.ID = lastID return nil } // GetByCond . func (*RuleDaoImpl) GetByCond(ctx context.Context, cond *Condition) (rules []*Rule, totalCounts int64, err error) { sqlConds := make([]string, 0) if cond.Area != "" { sqlConds = append(sqlConds, fmt.Sprintf("area = %s", cond.Area)) } if cond.State != "" { sqlConds = append(sqlConds, fmt.Sprintf("state = %s", cond.State)) } var optionSQL string if len(sqlConds) > 0 { optionSQL = fmt.Sprintf("WHERE %s", strings.Join(sqlConds, " AND ")) } var limitSQL string if cond.Pagination != nil { queryCountsSQL := fmt.Sprintf(selectRuleCountsSQL, optionSQL) totalCounts, err = GetTotalCounts(ctx, db, queryCountsSQL) if err != nil { return nil, 0, err } offset, limit := cond.OffsetLimit(totalCounts) if limit == 0 { return nil, 0, ErrResourceNotExist } limitSQL = fmt.Sprintf("LIMIT %d, %d", offset, limit) } if cond.OrderBy != "" { optionSQL = fmt.Sprintf("%s ORDER BY %s %s", optionSQL, cond.OrderBy, cond.Order) } if limitSQL != "" { optionSQL = fmt.Sprintf("%s %s", optionSQL, limitSQL) } querySQL := fmt.Sprintf(selectRulesByCondSQL, optionSQL) log.Info("OptionSQL(%s), GetByCondSQL(%s)", optionSQL, querySQL) rules, err = queryRules(ctx, db, querySQL) if err != nil { return nil, totalCounts, err } return rules, totalCounts, nil } // Update . func (rdi *RuleDaoImpl) Update(ctx context.Context, r *Rule) (*Rule, error) { if err := updateRule(ctx, db, r); err != nil { return nil, err } return rdi.GetByAreaAndTypeAndScope(ctx, &Condition{ Area: fmt.Sprintf("%d", r.Area), LimitType: fmt.Sprintf("%d", r.LimitType), LimitScope: fmt.Sprintf("%d", r.LimitScope), }) } // Insert . func (rdi *RuleDaoImpl) Insert(ctx context.Context, r *Rule) (*Rule, error) { if err := insertRule(ctx, db, r); err != nil { return nil, err } return rdi.GetByID(ctx, r.ID) } // GetByID . func (rdi *RuleDaoImpl) GetByID(ctx context.Context, id int64) (*Rule, error) { rs, err := rdi.GetByIDs(ctx, []int64{id}) if err != nil { return nil, err } if rs[0] == nil { return nil, ErrResourceNotExist } return rs[0], nil } // GetByIDs . func (*RuleDaoImpl) GetByIDs(ctx context.Context, ids []int64) ([]*Rule, error) { rs, err := queryRules(ctx, db, fmt.Sprintf(selectRuleByIDsSQL, util.IntSliToSQLVarchars(ids))) if err != nil { return nil, err } res := make([]*Rule, len(ids)) for i, id := range ids { for _, r := range rs { if r.ID == id { res[i] = r } } } return res, nil } // GetByAreaAndLimitType . func (*RuleDaoImpl) GetByAreaAndLimitType(ctx context.Context, cond *Condition) ([]*Rule, error) { return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeSQL, cond.Area, cond.LimitType)) } // GetByAreaAndTypeAndScope . func (*RuleDaoImpl) GetByAreaAndTypeAndScope(ctx context.Context, cond *Condition) (*Rule, error) { rs, err := queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaAndTypeAndScopeSQL, cond.Area, cond.LimitType, cond.LimitScope, )) if err != nil { return nil, err } return rs[0], nil } // GetByArea . func (*RuleDaoImpl) GetByArea(ctx context.Context, cond *Condition) ([]*Rule, error) { return queryRules(ctx, db, fmt.Sprintf(selectRulesByAreaSQL, cond.Area)) } func queryRules(ctx context.Context, q Querier, rawSQL string) ([]*Rule, error) { log.Info("Query sql: %q", rawSQL) rows, err := q.Query(ctx, rawSQL) if err == sql.ErrNoRows { err = ErrResourceNotExist } if err != nil { log.Error("Error: %v, RawSQL: %s", err, rawSQL) return nil, err } defer rows.Close() rs, err := mapRowToRules(rows) if err != nil { return nil, err } if len(rs) == 0 { return nil, ErrResourceNotExist } return rs, nil } func mapRowToRules(rows *sql.Rows) (rs []*Rule, err error) { for rows.Next() { r := Rule{} err = rows.Scan( &r.ID, &r.Area, &r.LimitType, &r.LimitScope, &r.DurationSec, &r.AllowedCounts, &r.CTime, &r.MTime, ) if err != nil { log.Error("%v", err) return nil, err } rs = append(rs, &r) } if err = rows.Err(); err != nil { log.Error("%v", err) return nil, err } return rs, nil }