2019-04-22 02:59:20 +00:00

257 lines
5.3 KiB
Go

package websocket
import (
"encoding/binary"
"errors"
"fmt"
"io"
"go-common/app/service/main/broadcast/libs/bufio"
)
const (
	// Frame header byte 0 (RFC 6455 section 5.2): FIN, RSV1-3, 4-bit opcode.
	finBit  = 0x80
	rsv1Bit = 0x40
	rsv2Bit = 0x20
	rsv3Bit = 0x10
	opBit   = 0x0f

	// Frame header byte 1 (RFC 6455 section 5.2): MASK, 7-bit payload length.
	maskBit = 0x80
	lenBit  = 0x7f

	// continuationFrame is the opcode of a continuation (non-first) fragment.
	continuationFrame = 0
	// continuationFrameMaxRead bounds how many frames ReadMessage will read
	// while trying to complete a single message.
	continuationFrameMaxRead = 100
)
// The message types are defined in RFC 6455, section 11.8.
const (
	// TextMessage denotes a text data message. The text message payload is
	// interpreted as UTF-8 encoded text data.
	TextMessage = 1
	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2
	// CloseMessage denotes a close control message. The optional message
	// payload contains a numeric code and text. Use the FormatCloseMessage
	// function to format a close message payload.
	CloseMessage = 8
	// PingMessage denotes a ping control message. The optional message payload
	// is UTF-8 encoded text.
	PingMessage = 9
	// PongMessage denotes a pong control message. The optional message payload
	// is UTF-8 encoded text.
	PongMessage = 10
)
var (
	// ErrMessageClose is returned by ReadMessage when the peer sends a Close
	// control frame.
	ErrMessageClose = errors.New("close control message")
	// ErrMessageMaxRead is returned by ReadMessage when too many frames are
	// read without completing a message.
	ErrMessageMaxRead = errors.New("continuation frame max read")
)
// Conn represents a WebSocket connection. Frames are read from r and written
// into w; rwc is the transport that Close shuts down.
type Conn struct {
	rwc io.ReadWriteCloser // underlying transport, closed by Close
	r   *bufio.Reader      // buffered source of incoming frames
	w   *bufio.Writer      // buffered sink for outgoing frames
}
// newConn bundles the transport rwc with its buffered reader r and buffered
// writer w into a Conn.
func newConn(rwc io.ReadWriteCloser, r *bufio.Reader, w *bufio.Writer) *Conn {
	conn := new(Conn)
	conn.rwc, conn.r, conn.w = rwc, r, w
	return conn
}
// WriteMessage writes msg as a single final (FIN) frame of type msgType:
// the header via WriteHeader, then the payload via WriteBody. It does not
// call Flush itself; use Flush to push buffered output to the peer.
func (c *Conn) WriteMessage(msgType int, msg []byte) error {
	err := c.WriteHeader(msgType, len(msg))
	if err == nil {
		err = c.WriteBody(msg)
	}
	return err
}
// WriteHeader writes the header of a single final (FIN) frame with opcode
// msgType and the given payload length, using the 7-bit, 16-bit or 64-bit
// length encoding of RFC 6455 section 5.2. The MASK bit is left clear and
// no masking key is written.
//
// The header is built in place inside slices handed out by the writer's
// Peek: one for the two fixed bytes and, when needed, a second one for the
// extended length.
func (c *Conn) WriteHeader(msgType int, length int) (err error) {
	var fixed, ext []byte
	if fixed, err = c.w.Peek(2); err != nil {
		return err
	}
	// Byte 0: FIN set, RSV1-3 clear, opcode in the low nibble.
	fixed[0] = finBit | byte(msgType)
	// Byte 1: MASK clear plus the 7-bit length or an extended-length marker.
	// It is fully written before any further Peek; keep that order, since a
	// further Peek may presumably flush the bytes already reserved (TODO
	// confirm against libs/bufio).
	if length <= 125 {
		fixed[1] = byte(length)
		return nil
	}
	if length < 65536 {
		fixed[1] = 126
		if ext, err = c.w.Peek(2); err == nil {
			binary.BigEndian.PutUint16(ext, uint16(length))
		}
		return err
	}
	fixed[1] = 127
	if ext, err = c.w.Peek(8); err == nil {
		binary.BigEndian.PutUint64(ext, uint64(length))
	}
	return err
}
// WriteBody writes b, the payload following a header written by WriteHeader,
// into the buffered writer. An empty b writes nothing and returns nil.
func (c *Conn) WriteBody(b []byte) error {
	if len(b) == 0 {
		return nil
	}
	_, err := c.w.Write(b)
	return err
}
// Peek returns the writer's next n bytes by delegating to its Peek. As
// WriteHeader shows, callers fill the returned slice in place instead of
// calling Write.
func (c *Conn) Peek(n int) ([]byte, error) {
	w := c.w
	return w.Peek(n)
}
// Flush pushes any buffered outgoing frame data to the underlying writer.
func (c *Conn) Flush() (err error) {
	err = c.w.Flush()
	return
}
// ReadMessage reads the next data message from the connection.
//
// An unfragmented Text or Binary frame is returned directly with the payload
// produced by readFrame (no copy). A fragmented message (RFC 6455 section
// 5.4) is reassembled: op is the opcode of the first fragment and payload is
// a newly allocated concatenation of every fragment's payload.
//
// Control frames arriving before the message completes are handled inline:
// a Ping is answered by buffering a Pong with the same payload, a Pong is
// ignored, and a Close ends the read with ErrMessageClose. An unknown
// opcode, a continuation frame with no message in progress, or a new data
// frame that interrupts a fragmented message is reported as an error.
// ErrMessageMaxRead is returned once more than continuationFrameMaxRead+1
// frames (control frames included) have been read without completing a
// message.
func (c *Conn) ReadMessage() (op int, payload []byte, err error) {
	var (
		fin         bool
		finOp, n    int // finOp: opcode of the message being reassembled, 0 if none
		partPayload []byte
	)
	for {
		if fin, op, partPayload, err = c.readFrame(); err != nil {
			return
		}
		switch op {
		case BinaryMessage, TextMessage:
			if finOp != 0 {
				// RFC 6455 section 5.4: fragments of different messages must
				// not be interleaved.
				err = fmt.Errorf("unexpected data frame op=%d inside fragmented message op=%d", op, finOp)
				return
			}
			if fin {
				// Unfragmented message: return the frame payload as-is.
				return op, partPayload, nil
			}
			// First fragment: remember the message type and start collecting.
			// append copies, so later reads cannot clobber collected bytes.
			finOp = op
			payload = append(payload, partPayload...)
		case continuationFrame:
			if finOp == 0 {
				err = errors.New("unexpected continuation frame without a preceding data frame")
				return
			}
			payload = append(payload, partPayload...)
			if fin {
				// Final fragment: report the first fragment's opcode, never 0.
				op = finOp
				return
			}
		case PingMessage:
			// Answer with a Pong echoing the Ping payload.
			// NOTE(review): the Pong is only buffered here (no Flush), and
			// this writes to c.w from the read path; confirm callers flush it
			// and do not write frames concurrently.
			if err = c.WriteMessage(PongMessage, partPayload); err != nil {
				return
			}
		case PongMessage:
			// Pong frames are ignored.
		case CloseMessage:
			err = ErrMessageClose
			return
		default:
			err = fmt.Errorf("unknown control message, fin=%t, op=%d", fin, op)
			return
		}
		if n > continuationFrameMaxRead {
			err = ErrMessageMaxRead
			return
		}
		n++
	}
}
// readFrame reads one WebSocket frame (RFC 6455 section 5.2) and returns its
// FIN flag, opcode and payload, unmasking the payload in place when the MASK
// bit is set.
//
// The payload comes straight from the reader's Pop, so it presumably aliases
// the reader's buffer (TODO confirm how long it stays valid). Frames are
// rejected with an error when:
//   - any RSV bit is set;
//   - a 64-bit length has its most significant bit set or does not fit in an int;
//   - a control frame (opcode 0x8-0xF) is fragmented or carries more than
//     125 bytes (RFC 6455 section 5.5).
func (c *Conn) readFrame() (fin bool, op int, payload []byte, err error) {
	var (
		b          byte
		p          []byte
		mask       bool
		maskKey    [4]byte
		payloadLen int64
	)
	// 1.First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
	if b, err = c.r.ReadByte(); err != nil {
		return
	}
	fin = (b & finBit) != 0
	// rsv MUST be 0
	if rsv := b & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
		return false, 0, nil, fmt.Errorf("unexpected reserved bits rsv1=%d, rsv2=%d, rsv3=%d", b&rsv1Bit, b&rsv2Bit, b&rsv3Bit)
	}
	op = int(b & opBit)
	// 2.Second byte. Mask/Payload len(7bits)
	if b, err = c.r.ReadByte(); err != nil {
		return
	}
	mask = (b & maskBit) != 0
	switch b & lenBit {
	case 126:
		// 16-bit extended length
		if p, err = c.r.Pop(2); err != nil {
			return
		}
		payloadLen = int64(binary.BigEndian.Uint16(p))
	case 127:
		// 64-bit extended length
		if p, err = c.r.Pop(8); err != nil {
			return
		}
		ext := binary.BigEndian.Uint64(p)
		// The most significant bit MUST be 0; without this check the length
		// turns negative, the payload is skipped and the stream desyncs. Also
		// reject lengths an int cannot hold (32-bit platforms).
		if ext>>63 != 0 || uint64(int(ext)) != ext {
			return false, 0, nil, fmt.Errorf("invalid payload length %d", ext)
		}
		payloadLen = int64(ext)
	default:
		// 7-bit length
		payloadLen = int64(b & lenBit)
	}
	// Control frames MUST NOT be fragmented and carry at most 125 bytes.
	if op >= CloseMessage && (!fin || payloadLen > 125) {
		return false, 0, nil, fmt.Errorf("invalid control frame, fin=%t, op=%d, len=%d", fin, op, payloadLen)
	}
	// read mask key
	if mask {
		if p, err = c.r.Pop(4); err != nil {
			return
		}
		// Pop may return a view into the reader's internal buffer, and the
		// payload Pop below can refill/compact that buffer and overwrite
		// those bytes; keep a private copy so unmasking uses the right key.
		copy(maskKey[:], p)
	}
	// read payload
	if payloadLen > 0 {
		if payload, err = c.r.Pop(int(payloadLen)); err != nil {
			return
		}
		if mask {
			maskBytes(maskKey[:], 0, payload)
		}
	}
	return
}
// Close closes the underlying transport. It neither flushes buffered output
// nor sends a WebSocket Close frame.
func (c *Conn) Close() (err error) {
	rwc := c.rwc
	err = rwc.Close()
	return
}
// maskBytes XORs every byte of b in place with the 4-byte masking key,
// starting at key offset pos (RFC 6455 section 5.3). Applying it twice with
// the same key and pos restores the input. It returns the key offset at
// which the next byte would continue, so a payload can be masked in pieces.
func maskBytes(key []byte, pos int, b []byte) int {
	for i := 0; i < len(b); i++ {
		b[i] ^= key[(pos+i)&3]
	}
	return (pos + len(b)) & 3
}