September 27, 2020

WebSocket Implemention With Go

使用Go来实现WebSocket协议

什么是WS协议

The WebSocket Protocol enables two-way communication between a client running untrusted code in a controlled environment to a remote host that has opted-in to communications from that code. The security model used for this is the origin-based security model commonly used by web browsers. The protocol consists of an opening handshake followed by basic message framing, layered over TCP.

The goal of this technology is to provide a mechanism for browser-based applications that need two-way communication with servers that does not rely on opening multiple HTTP connections (e.g., using XMLHttpRequest or <iframe>s and long.polling). - 摘自 RFC6455 Abstract.

用大白话就是: “基于TCP传输协议构建的全双工,用于Web浏览器和服务端进行通信的应用层协议”。那么同属于应用层协议,WebSocket和HTTP之间有何异同?

特点 HTTP WebSocket
OSI Layer 7 7
通信模式 单工/半双工/全双工,取决于HTTP版本 全双工
工作端口 80/443 80/443作为默认,但不限于
是否支持SSL
传输协议 TCP TCP
数据格式 JSON / TEXT / Binary Stream 等 JSON / TEXT / Binary Stream 等
协议Schema http/https ws/wss
最新版本 3 13

为什么要用WS协议

想一下,自己会在什么场景下考虑WebSocket而不是HTTP协议就明白了。在没有 WebSocket 之前,了实现即时通信有以下几种方案:

  • Short Polling
  • Long Polling
  • Streaming
  • SSE (Server Sent Event)

这部分的演进可以阅读 https://halfrost.com/websocket/ 本文也多有参考。

实现

本文的主要目标是为了介绍如何使用Go实现WebSocket协议,因此最好是具备以下能力:

  • 会使用 Wireshark 或者 tcpdump
  • 熟练使用 Go 语言开发
  • 能读懂RFC6455的英语能力,当然你用翻译也可以,至少别因为翻译错误而无法继续~
  • 对TCP有一定理解

阅读RFC6455,阅读RFC6455,阅读RFC6455!!! 这是一切的起点,没必要急于动手。

1. 理解WebSocket的协议帧和工作流程

整体流程:

时序图:

协议帧:

协议帧 是最重要的部分,理解了协议帧就相当于完成了50%,余下的就是撸代码了… 协议帧主要包括了以下几个部分:

  • 头部
    • FIN [1bit] 指示该帧是不是消息分片的最后一帧。分片-传送门
    • RSV 保留位 [3bit]
    • MASK 是否使用掩码 [1bit]
    • PAYLOAD LEN 数据部分的长度 [7bit] 如果等于 126/127(7bit=128-1)则说明启用 PAYLOAD LEN EXT(1/2) 部分来表明数据部分的长度。
    • PAYLOAD LEN EXT 仅当 PAYLOAD LEN 无法表示数据部分的长度时启用 [16bit/64bit]
    • MASKING KEY 掩码 [0/32bit] 取决是MASK是否设置。掩码-传送们
  • 数据, 长度等于 Payload Len或者Payload Len Ext,请忽略图中一个bit一个字符…

如果认真看图的话和看过RFC6455之后,你就会发现数据部分的真实长度,其实分为三种情况:

  • Payload Len 的值,小于 126时,没有扩展的Payload Len Ext部分,值本身表达数据部分的长度
  • Payload Len 的值,等于 126时,Payload Len Ext1 启用,长度16bit
  • Payload Len 的值,等于 127时,Payload Len Ext2 启用,长度64bit

2. 协议帧frame的基本处理

看过协议帧之后,接下来就是数据帧的处理,这部分需要的就是编码知识了。那么在Go里面如何去表达一个帧呢?一个帧在TCP眼里就是字节流,发送的时候是,接收的时候也是,因此帧的主要工作就是:将协议数据按照指定的格式塞到字节流里面去,亦或者是从字节流中解析到我们用于表达的数据结构中去。就像一个翻译官,按照特定的语法来回翻译。

// Frame 这里使用uint16来承载,只是因为统一,方便计算移位,因为整个头部最重要的部分就是16bit的长度。
type Frame struct {
    Fin    uint16 // 1 bit
    RSV1   uint16 // 1 bit, 0
    RSV2   uint16 // 1 bit, 0
    RSV3   uint16 // 1 bit, 0
    OpCode OpCode // 4 bits
    Mask   uint16 // 1 bit

    // Payload length:  7 bits, 7+16 bits, or 7+64 bits
    //
    // if PayloadLen = 0 - 125, actual_payload_length = PayloadLen
    // if PayloadLen = 126, 	actual_payload_length = PayloadExtendLen[:16]
    // if PayloadLen = 127, 	actual_payload_length = PayloadExtendLen[:]
    PayloadLen       uint16 // 7 bits
    PayloadExtendLen uint64 // 64 bits

    MaskingKey uint32 // 32 bits
    Payload    []byte // no limit by RFC6455
}

翻译的部分代码如下:

这里对于 “大小端序” 没有概念的,可以先自行查阅资料。

// 将帧翻译到字节流
func encodeFrameTo(frm *Frame) []byte {
    buf := make([]byte, 2, minFrameHeaderSize+8)

    var (
        part1 uint16 // from FIN to PayloadLen
    )

    part1 |= frm.Fin << finOffset
    part1 |= frm.RSV1 << rsv1Offset
    part1 |= frm.RSV2 << rsv2Offset
    part1 |= frm.RSV3 << rsv3Offset
    part1 |= uint16(frm.OpCode) << opcodeOffset
    part1 |= frm.Mask << maskOffset
    part1 |= frm.PayloadLen << payloadLenOffset

    // start from 0th byte
    // fill part1 into 2 byte
    binary.BigEndian.PutUint16(buf[:2], part1)

    // FIXED: fill payloadExtendLen into 8 byte
    switch frm.PayloadLen {
    case 126:
        payloadExtendBuf := make([]byte, 2)
        binary.BigEndian.PutUint16(payloadExtendBuf[:2], uint16(frm.PayloadExtendLen))
        buf = append(buf, payloadExtendBuf...)
    case 127:
        payloadExtendBuf := make([]byte, 8)
        binary.BigEndian.PutUint64(payloadExtendBuf[:8], frm.PayloadExtendLen)
        buf = append(buf, payloadExtendBuf...)
    }

    // FIXED: if not mask, then no set masking key
    if frm.Mask == 1 {
        // fill fmtMaskingKey into 4 byte
        maskingKeyBuf := make([]byte, 4)
        binary.BigEndian.PutUint32(maskingKeyBuf[:4], frm.MaskingKey)
        buf = append(buf, maskingKeyBuf...)
    }

    // header done, start writing body
    buf = append(buf, frm.Payload...)

    return buf
}

// 解析帧的头部(16bit)信息,因为WebSocket协议帧中:PayloadLen 是变长,
// MaskingKey 也是有或者没有,但是都可以通过16bit的数据来获得准确的结果。
// 解析过程也就是写入的逆向操作。
func parseFrameHeader(header []byte) *Frame {
    var (
        frm   = new(Frame)
        part1 = binary.BigEndian.Uint16(header[:2])
    )

    frm.Fin = (part1 & finMask) >> finOffset
    frm.RSV1 = (part1 & rsv1Mask) >> rsv1Offset
    frm.RSV2 = (part1 & rsv2Mask) >> rsv2Offset
    frm.RSV3 = (part1 & rsv3Mask) >> rsv3Offset
    frm.OpCode = OpCode((part1 & opcodeMask) >> opcodeOffset)
    frm.Mask = (part1 & maskMask) >> maskOffset
    frm.PayloadLen = (part1 & payloadLenMask) >> payloadLenOffset

    return frm
}

3. WebSocket链接的定义

Conn 是对TCP链接的封装再配合上协议,来对客户端和服务端的提供功能。其中我个人觉得最重要的功能是:如何从TCP字节流中读到一个完整的WebSocket消息 Conn.ReadMessage()

type Conn struct {
    conn  net.Conn       // 底层TCP链接
    bufRD *bufio.Reader  // 读
    bufWR *bufio.Writer  // 写
    // 省略部分不是特别重要的字段
}

// 构造websocket.Conn
func newConn(netconn net.Conn, isServer bool) (*Conn, error) {
    c := Conn{
        conn:   netconn,
        bufRD:  bufio.NewReaderSize(netconn, 65535), // 65535B = 64KB
        bufWR:  bufio.NewWriter(netconn),
    }

    return &c, nil
}


// ReadMessage . it will block to read message
func (c *Conn) ReadMessage() (mt MessageType, msg []byte, err error) {
    frm, err := c.readFrame()
    if err != nil {
        debugErrorf("Conn.ReadMessage failed to c.readFrame, err=%v", err)
        return NoFrame, nil, err
    }
    mt = MessageType(frm.OpCode)

    // 根据读到的帧判断是否还有后续的帧,如果有分片,那就读完将payload组装到一起。
    buf := bytes.NewBuffer(nil)
    buf.Write(frm.Payload)
    for !frm.isFinal() {
        if frm, err = c.readFrame(); err != nil {
            debugErrorf("Conn.ReadMessage failed to c.readFrame, err=%v", err)
            return NoFrame, nil, err
        }
        buf.Write(frm.Payload)
    }

    msg = buf.Bytes()
    return
}

// 从缓冲区读取指定字节数量的数据
func (c *Conn) read(n int) ([]byte, error) {
    p, err := c.bufRD.Peek(n)
    if err == io.EOF {
        err = ErrUnexpectedEOF
        return nil, err
    }
    _, _ = c.bufRD.Discard(len(p))
    return p, err
}

func (c *Conn) readFrame() (*Frame, error) {
    // 阻塞地读2Byte的数据
    p, err := c.read(2)
    if err != nil {
        debugErrorf("Conn.readFrame failed to c.read(header), err=%v", err)
        return nil, err
    }

    // 解析WebSocket帧头部
    frmWithoutPayload := parseFrameHeader(p)
    logger.Debugf("Conn.readFrame got frmWithoutPayload=%+v", frmWithoutPayload)

    var (
        payloadExtendLen uint64 // this could be non exist
        remaining        uint64
    )

    // 根据PayloadLen来读取不同字节数的扩展长度
    // 126 -> 2B
    // 127 -> 8B
    switch frmWithoutPayload.PayloadLen {
    case 126:
        // has 16bit + 32bit = 6B
        p, err = c.read(2)
        if err != nil {
            debugErrorf("Conn.readFrame failed to c.read(2) payloadlen with 16bit, err=%v", err)
            return nil, err
        }
        payloadExtendLen = uint64(binary.BigEndian.Uint16(p[:2]))
        remaining = payloadExtendLen
    case 127:
        // has 64bit + 32bit = 12B
        p, err = c.read(8)
        if err != nil {
            debugErrorf("Conn.readFrame failed to c.read(8) payloadlen with 16bit, err=%v", err)
            return nil, err
        }
        payloadExtendLen = binary.BigEndian.Uint64(p[:8])
        remaining = payloadExtendLen
    default:
        remaining = uint64(frmWithoutPayload.PayloadLen)
    }
    frmWithoutPayload.PayloadExtendLen = payloadExtendLen

    // get masking key
    if frmWithoutPayload.Mask == 1 {
        // only 32bit masking key to read
        p, err = c.read(4)
        if err != nil {
            debugErrorf("Conn.readFrame failed to c.read(header), err=%v", err)
            return nil, err
        }
        frmWithoutPayload.MaskingKey = binary.BigEndian.Uint32(p)
    }

    // 省略frame校验过程
    // 读取payload数据并填充到frame中去
    var (
        payload = make([]byte, 0, remaining)
    )

    logger.Debugf("Conn.readFrame c.read(%d) into payload data", remaining)
    for remaining > 65535 {
        // true: bufio.Reader can read 65535 byte as most at once
        p, err := c.read(65535)
        if err != nil {
            debugErrorf("Conn.readFrame failed to c.read(payload), err=%v", err)
            return nil, err
        }
        payload = append(payload, p...)
        remaining -= 65535
    }

    // 读取剩余部分的payload
    p, err = c.read(int(remaining))
    if err != nil {
        debugErrorf("Conn.readFrame failed to c.read(payload), err=%v", err)
        return nil, err
    }
    payload = append(payload, p...)
    frmWithoutPayload.setPayload(payload)

    // 处理ping pong close 帧
    switch frmWithoutPayload.OpCode {
    case opCodeText, opCodeBinary, opCodeContinuation:
        // pass
    case opCodePing:
        err = c.replyPing(frmWithoutPayload)
    case opCodePong:
        err = c.replyPong(frmWithoutPayload)
    case opCodeClose:
        err = c.handleClose(frmWithoutPayload)
    }

    return frmWithoutPayload, err
}

4. 服务端和客户端的定义

到这一步,已经把底层的工作都完成了:定义协议帧协议帧翻译Conn约定和封装等工作。现在可以开始设计和实现顶层API了。 服务端API,我参考了gorilla/websocket,定义了一个 Upgrader 来将HTTP升级到 Websocket

// Upgrader std.HTTP / fasthttp / gin etc
type Upgrader struct {
    CheckOrigin func(req *http.Request) bool

    Timeout time.Duration
}

// 升级协议。如果遇到了hijack错误,可以看这里的连接,希望有所帮助
// https://stackoverflow.com/questions/32657603/why-do-i-get-the-error-message-http-response-write-on-hijacked-connection
//
func (ug Upgrader) Upgrade(w http.ResponseWriter, req *http.Request, fn func(conn *Conn)) error {
    // 设置超时上下文
    ctx, cancel := context.WithTimeout(req.Context(), timeout)
    req = req.WithContext(ctx)
    defer cancel()

    // RFC6455 完成握手检查
    // almost checking is about headers
    if err := ug.handshakeCheck(w, req); err != nil {
        debugErrorf("Upgrader.Upgrade failed to ug.handshakeCheck, err=%v", err)
        return ug.returnError(w, http.StatusBadRequest, err.Error())
    }

    // !!! 重要,获取HTTP对应的底层TCP链接
    h := w.(http.Hijacker)
    netconn, brw, err = .Hijack()
    if err != nil {
        debugErrorf("Upgrader.Upgrade failed to h.Hijack, err=%v", err)
        _ = ug.returnError(w, http.StatusInternalServerError, err.Error())
        return nil
    }

    // 省略请求头处理 ...

    // 发送HTTP响应
    if err = hackHandshakeResponse(brw.Writer, respHeaders, "101"); err != nil {
        return err
    }
    
    // 启动一个goroutine来处理该链接的消息,ConnPool / Reactor 的实现,可以在这里调整
    conn, _ := newConn(netconn, true)
    go func() {
        defer func() {
            if err, ok := recover().(error); ok {
                logger.Errorf("Upgrader.Upgrade fn panic: err=%v", err)
                debug.PrintStack()
            }
        }()

        fn(conn)
    }()

    return nil
}

客户端的话就更简单了,建立一个TCP连接,向服务端发起HTTP升级到Websocket的请求,等握手通过那么连接就建立完成了,就可以对websocket.Conn进行读写操作了。


5. 建链和握手

在上一小节已经带过了这部分,这里重点想要介绍下http请求到websocket的升级过程。服务端例子如下:

func main() {
    http.HandleFunc("/echo", echo)

    if err := http.ListenAndServe(":8080", nil); err != nil {
        log.Fatal(err)
    }
}

func echo(w http.ResponseWriter, req *http.Request) {
    websocketHdl := func(conn *websocket.Conn) {
        for {
            // 读消息
            mt, message, err := conn.ReadMessage()
            // 发送消息
            err = conn.SendMessage(string(message))
        }
        log.Info("conn finished")
    }

    // 调用 upgrader 升级,并调用连接处理方法 - goroutine 保持
    err := upgrader.Upgrade(w, req, websocketHdl)
    if err != nil {
        log.Errorf("upgrade error, err=%v", err)
        return
    }

    log.Infof("conn upgrade done")
}

由上述代码我们可以明白,所谓的升级过程是使用HTTP协议来完成,对于服务端更友好,很容易在现有的HTTP服务中加上一个WebSocket服务。握手 - 传送门。用大白话描述就是:

  1. 客户端通过 schema://host:port/path/to/websocket 这样一个地址找到服务端 [建立TCP连接]
  2. 客户端发起一个HTTP请求(携带特殊的请求头)[客户端发起握手]
  3. 服务端通过验证后将该连接转换为websocket连接保持在服务端 [服务端握手检查和升级]
  4. 服务端握手完成后,同时向客户端发送一个HTTP响应(成功或失败)[服务端发送响应]
  5. 客户端根据服务端的响应处理该连接 [客户端处理响应]
  6. 握手完成

6. 完善细节

前面几步完成,WebSocket的框架就已经成型了,剩下的工作就是根据协议完善。比如对关闭帧的处理,Ping/Pong帧的处理等等。这里简单举例说明 Ping / Pong 的处理。

func (c *Conn) readFrame() (*Frame, error) {
    // ignore cases
    
    // handle with close, ping, pong frame
    switch frmWithoutPayload.OpCode {
        // ignore cases ...
    case opCodePing:
        err = c.replyPing(frmWithoutPayload)
    case opCodePong:
        err = c.replyPong(frmWithoutPayload)
        // ignore some cases ...
    }

    return frmWithoutPayload, err
}

// Ping conn send a ping packet to another side.
func (c *Conn) Ping() (err error) {
    return c.sendControlFrame(opCodePing, []byte("ping"))
}

// replyPing work for Conn to reply ping packet. frame MUST contains 125 Byte or-
// less payload.
func (c *Conn) replyPing(frm *Frame) (err error) {
    return c.pong(frm.Payload)
}

// pong .
func (c *Conn) pong(pingPayload []byte) (err error) {
    return c.sendControlFrame(opCodePong, pingPayload)
}

// replyPong frame MUST contains same payload with PING frame payload
func (c *Conn) replyPong(frm *Frame) (err error) {
    // if receive pong frame, try to call pongHandler
    if c.pongHandler != nil {
        c.pongHandler(string(frm.Payload))
    }

    return nil
}

总结

实现WebSocket协议并没有什么难点,只要你读完RFC6455就行了,动手去实现只是为了加深认识,尤其是对于网络和协议的认识。实现起来简单另外一个原因是因为毕竟是应用层协议,基于TCP可靠传输。传输层协议对我们屏蔽了大量的网络细节问题~,如果想要挑战自己,可以尝试实现TCP或者UDP🐶。

水平有限,如有错误,欢迎勘误指正🙏。

参考资料