bilibili-backup/app/infra/databus/tcp/conn.go
2019-04-22 02:59:20 +00:00

262 lines
5.4 KiB
Go

package tcp
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
)
// conn wraps a net.Conn with buffered reading and writing of the
// length-prefixed, CRLF-terminated wire format handled by this package,
// plus optional per-operation read and write deadlines.
type conn struct {
	conn net.Conn // underlying network connection

	// Inbound side: deadline applied before each Read, and its buffered reader.
	readTimeout time.Duration
	br          *bufio.Reader

	// Outbound side: deadline applied before writes and flushes, and its
	// buffered writer.
	writeTimeout time.Duration
	bw           *bufio.Writer

	// lenScratch is reusable space for a length header:
	// a '*' or '$' prefix, the decimal length, then "\r\n".
	lenScratch [32]byte
	// numScratch is reusable space for formatting numeric replies.
	numScratch [40]byte
}
// newConn wraps netConn in a *conn whose reader and writer are buffered with
// _readBufSize and _writeBufSize bytes respectively. A timeout of zero or less
// disables the matching deadline: Read, Write and Flush only set a deadline
// when the configured timeout is positive.
func newConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) *conn {
	c := &conn{conn: netConn}
	c.readTimeout, c.writeTimeout = readTimeout, writeTimeout
	c.br = bufio.NewReaderSize(netConn, _readBufSize)
	c.bw = bufio.NewWriterSize(netConn, _writeBufSize)
	return c
}
// Read reads one command from the connection.
//
// Two request shapes are accepted:
//   - a bare inline line that, after lower-casing, equals _quit; it yields
//     cmd == _quit with no args;
//   - a header line "<prefix><count>" followed by count bulk strings, each
//     "$<len>\r\n<data>\r\n". The first bulk string becomes cmd (lower-cased),
//     and the remaining ones are returned in args.
//
// When readTimeout is positive a read deadline is set first. A malformed
// header, or a count below 1, is reported as an error instead of panicking.
func (c *conn) Read() (cmd string, args [][]byte, err error) {
	// Upper bound on the args capacity reserved up front, so a huge declared
	// count cannot force a huge allocation before any payload has arrived.
	// args still grows past this via append when more elements really arrive.
	const maxPrealloc = 64
	var (
		ln, cn int
		bs     []byte
	)
	if c.readTimeout > 0 {
		c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
	}
	// start read
	if bs, err = c.readLine(); err != nil {
		return
	}
	if len(bs) < 2 {
		err = fmt.Errorf("read error data(%s) from connection", bs)
		return
	}
	// A command without any parameters may arrive inline, e.g. QUIT.
	if strings.ToLower(string(bs)) == _quit {
		cmd = _quit
		return
	}
	// bs[0] is the header prefix; the rest is the element count.
	if ln, err = parseLen(bs[1:]); err != nil {
		return
	}
	// The count includes the command name, so it must be at least 1. Without
	// this check "*0" or "*-1" makes the make() below panic on a negative cap.
	if ln < 1 {
		err = fmt.Errorf("invalid multibulk length %d", ln)
		return
	}
	capHint := ln - 1
	if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	args = make([][]byte, 0, capHint)
	for i := 0; i < ln; i++ {
		if cn, err = c.readLen(_protoBulk); err != nil {
			return
		}
		if bs, err = c.readData(cn); err != nil {
			return
		}
		if i == 0 {
			cmd = strings.ToLower(string(bs))
			continue
		}
		args = append(args, bs)
	}
	return
}
// WriteError sends err to the peer as an error reply and then closes the
// connection. Delivery is best-effort: if buffering the reply fails the flush
// is skipped, and flush and close errors are ignored. The connection is
// closed in every case.
func (c *conn) WriteError(err error) {
	if c.writeTimeout > 0 {
		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
	}
	reply := proto{prefix: _protoErr, message: err.Error()}
	if c.Write(reply) == nil {
		c.Flush()
	}
	c.Close()
}
// Write encodes p into the buffered writer according to p.prefix. The bytes
// may not reach the network until Flush is called. For an array prefix only
// the "<prefix><p.integer>\r\n" header is written; the elements must be
// written by further calls. A prefix outside the handled set writes nothing
// and returns nil. When writeTimeout is positive a write deadline is set first.
func (c *conn) Write(p proto) error {
	if c.writeTimeout > 0 {
		c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
	}
	switch prefix := p.prefix; prefix {
	case _protoStr:
		return c.writeStatus(p.message)
	case _protoErr:
		return c.writeError(p.message)
	case _protoInt:
		return c.writeInt64(int64(p.integer))
	case _protoBulk:
		// Goes through writeBytes (not writeString) so an empty message is
		// encoded by writeBytes' empty-payload branch.
		return c.writeBytes([]byte(p.message))
	case _protoArray:
		return c.writeLen(prefix, p.integer)
	default:
		return nil
	}
}
// Flush pushes any buffered output to the network connection, first setting a
// write deadline when writeTimeout is positive, and returns the flush error.
func (c *conn) Flush() error {
	if d := c.writeTimeout; d > 0 {
		c.conn.SetWriteDeadline(time.Now().Add(d))
	}
	return c.bw.Flush()
}
// Close closes the underlying network connection without flushing buffered
// output, and returns the close error.
func (c *conn) Close() (err error) {
	err = c.conn.Close()
	return
}
// parseLen parses the decimal length that follows a '*' (array) or '$'
// (bulk string) header prefix; p must not include the prefix byte.
//
// The literal "-1" (the null forms "$-1" and "*-1") yields -1 with a nil
// error. Any other input must be a non-empty run of ASCII digits whose value
// fits in an int; otherwise -1 and a non-nil error are returned.
func parseLen(p []byte) (int, error) {
	if len(p) == 0 {
		return -1, errors.New("malformed length")
	}
	if len(p) == 2 && p[0] == '-' && p[1] == '1' {
		// handle $-1 and *-1 null replies.
		return -1, nil
	}
	const maxInt = int(^uint(0) >> 1)
	var n int
	for _, b := range p {
		if b < '0' || b > '9' {
			return -1, errors.New("illegal bytes in length")
		}
		d := int(b - '0')
		// Refuse to overflow: unchecked, a long digit run wraps to an
		// arbitrary (possibly negative) int that callers use as a size.
		if n > (maxInt-d)/10 {
			return -1, errors.New("length overflow")
		}
		n = n*10 + d
	}
	return n, nil
}
// readLine reads one CRLF-terminated line and returns it without the trailing
// "\r\n". The returned slice is a fresh copy owned by the caller.
//
// A line that does not fit in the reader's buffer is rejected with
// "long response line". This bounds how much a peer can make us buffer for
// a single line. A line ending in a bare "\n" is rejected as a bad terminator.
func (c *conn) readLine() ([]byte, error) {
	// ReadSlice, unlike ReadBytes, stops with bufio.ErrBufferFull when no '\n'
	// appears within the buffer, so the length guard below can actually fire.
	p, err := c.br.ReadSlice('\n')
	if err == bufio.ErrBufferFull {
		return nil, errors.New("long response line")
	}
	if err != nil {
		return nil, err
	}
	i := len(p) - 2
	if i < 0 || p[i] != '\r' {
		return nil, errors.New("bad response line terminator")
	}
	// p aliases the reader's internal buffer and is overwritten by the next
	// read; copy it so callers keep an independent slice as before.
	line := make([]byte, i)
	copy(line, p[:i])
	return line, nil
}
// readLen reads one header line, requires it to start with prefix and to have
// at least one byte after the prefix, and returns the length parsed from the
// remainder by parseLen. A missing or wrong prefix, or a line that is too
// short, yields "illegal bytes in length".
func (c *conn) readLen(prefix byte) (int, error) {
	line, err := c.readLine()
	switch {
	case err != nil:
		return 0, err
	case len(line) < 2, line[0] != prefix:
		// Case-list items are tried left to right, so line[0] is only read
		// once the length check has passed.
		return 0, errors.New("illegal bytes in length")
	}
	return parseLen(line[1:])
}
// readData reads a bulk-string payload of n bytes followed by its "\r\n"
// terminator and returns the n payload bytes.
//
// n must be non-negative and at most _maxValueSize. A negative n, such as the
// -1 parseLen returns for a "$-1" header, is rejected with an error instead of
// panicking on the slice expression. A payload not followed by CRLF is
// rejected as well.
func (c *conn) readData(n int) ([]byte, error) {
	if n < 0 {
		return nil, errors.New("invalid bulk length")
	}
	if n > _maxValueSize {
		return nil, errors.New("exceeding max value limit")
	}
	buf := make([]byte, n+2)
	if _, err := io.ReadFull(c.br, buf); err != nil {
		return nil, err
	}
	// io.ReadFull only succeeds once buf is full, so the byte count needs no
	// re-check; what must be verified is that the payload ends in CRLF.
	if buf[n] != '\r' || buf[n+1] != '\n' {
		return nil, errors.New("invalid bytes in len")
	}
	return buf[:n], nil
}
// writeLen buffers a length header: prefix, the decimal value of n, then
// "\r\n" (for example prefix '*' with n == 3 gives "*3\r\n"). The header is
// assembled in c.lenScratch to avoid an allocation. Negative n is written
// with its sign, so a null header such as "*-1\r\n" is encoded correctly.
func (c *conn) writeLen(prefix byte, n int) error {
	// At most 1 (prefix) + 20 (int64 digits incl. sign) + 2 (CRLF) = 23 bytes,
	// which always fits in the 32-byte scratch array, so append never
	// reallocates.
	b := append(c.lenScratch[:0], prefix)
	b = strconv.AppendInt(b, int64(n), 10)
	b = append(b, '\r', '\n')
	_, err := c.bw.Write(b)
	return err
}
// writeStatus buffers a status line: the _protoStr prefix byte, s, and "\r\n".
// Only the error of the final write is returned. bufio.Writer errors are
// sticky, so an earlier failure surfaces there too.
func (c *conn) writeStatus(s string) error {
	w := c.bw
	w.WriteByte(_protoStr)
	w.WriteString(s)
	_, err := w.WriteString("\r\n")
	return err
}
// writeError buffers an error line: the _protoErr prefix byte, s, and "\r\n".
// Only the error of the final write is returned. bufio.Writer errors are
// sticky, so an earlier failure surfaces there too.
func (c *conn) writeError(s string) error {
	w := c.bw
	w.WriteByte(_protoErr)
	w.WriteString(s)
	_, err := w.WriteString("\r\n")
	return err
}
// writeInt64 buffers an integer line: the _protoInt prefix byte, n in base 10,
// and "\r\n". The digits are formatted into c.numScratch to avoid allocating.
func (c *conn) writeInt64(n int64) error {
	digits := strconv.AppendInt(c.numScratch[:0], n, 10)
	c.bw.WriteByte(_protoInt)
	c.bw.Write(digits)
	_, err := c.bw.WriteString("\r\n")
	return err
}
// writeString buffers s as a bulk string: a _protoBulk length header, the
// bytes of s, and "\r\n". Unlike writeBytes, an empty s is written as a
// zero-length bulk string. Only the final write's error is returned.
func (c *conn) writeString(s string) error {
	w := c.bw
	c.writeLen(_protoBulk, len(s))
	w.WriteString(s)
	_, err := w.WriteString("\r\n")
	return err
}
// writeBytes buffers s as a bulk string: a _protoBulk length header, the
// payload, and "\r\n". An empty s is instead written as '$' followed by
// _nullBulk and "\r\n" (presumably a null bulk "$-1\r\n"; confirm against the
// _nullBulk definition). Only the final write's error is returned.
func (c *conn) writeBytes(s []byte) error {
	if len(s) > 0 {
		c.writeLen(_protoBulk, len(s))
		c.bw.Write(s)
	} else {
		c.bw.WriteByte('$')
		c.bw.Write(_nullBulk)
	}
	_, err := c.bw.WriteString("\r\n")
	return err
}
// writeStrings buffers ss as an array: a _protoArray header carrying len(ss),
// then each element as a bulk string via writeString. It stops at the first
// element whose write reports an error and returns that error.
func (c *conn) writeStrings(ss []string) error {
	c.writeLen(_protoArray, len(ss))
	for i := range ss {
		if err := c.writeString(ss[i]); err != nil {
			return err
		}
	}
	return nil
}