用Go寫一個簡單的TCP server or client 模型

對Conn封裝的基本思路

go內置了net包已經很好的封裝了socket通訊。然而在實際使用中,由於net/Conn的Read/Write方法是堵塞的原因,必須將其放入單獨的goroutine中進行處理。

我們先簡單的整理下思路,對於連接(Conn)的處理,我們可以開啓2條goroutine進行處理,一條用於堵塞的Read的處理,另一條進行Write的處理。

這裏必須指出,其實Write本身就是線程安全的,也就是我們在別的任何地方都可以進行Write,但是Write是堵塞的,所以考慮到這點,我們還是將其放入一個單獨的goroutine中進行處理。

這樣設計的原因在於Conn是支持同時Read/Write的,這樣我們的基本的Conn的模型就成型了。對於服務端或者客戶端而言,我們只需要封裝對應過來的Conn即可,Conn的讀寫goroutine進行處理,並將獲得的事件拋向外部。

那麼我們就按這個思路來實現一個簡單的Connection封裝,該封裝支持線程安全的寫,並且支持解包操作。

package tcpnetwork

import (
    "errors"
    "log"
    "net"
    "time"
)

const (
    kConnStatus_None = iota
    kConnStatus_Connected
    kConnStatus_Disconnected
)

const (
    kConnEvent_None = iota
    kConnEvent_Connected
    kConnEvent_Disconnected
    kConnEvent_Data
    kConnEvent_Close
)

const (
    kConnConf_DefaultSendTimeoutSec = 5
    kConnConf_MaxReadBufferLength   = 0xffff // 0xffff
)

type Connection struct {
    conn                net.Conn
    status              int
    connId              int
    sendMsgQueue        chan []byte
    sendTimeoutSec      int
    eventQueue          IEventQueue
    streamProtocol      IStreamProtocol
    maxReadBufferLength int
    userdata            interface{}
    from                int
    readTimeoutSec      int
}

func newConnection(c net.Conn, sendBufferSize int, eq IEventQueue) *Connection {
    return &Connection{
        conn:                c,
        status:              kConnStatus_None,
        connId:              0,
        sendMsgQueue:        make(chan []byte, sendBufferSize),
        sendTimeoutSec:      kConnConf_DefaultSendTimeoutSec,
        maxReadBufferLength: kConnConf_MaxReadBufferLength,
        eventQueue:          eq,
    }
}

type ConnEvent struct {
    EventType int
    Conn      *Connection
    Data      []byte
}

func newConnEvent(et int, c *Connection, d []byte) *ConnEvent {
    return &ConnEvent{
        EventType: et,
        Conn:      c,
        Data:      d,
    }
}

//  directly close, packages in queue will not be sent
func (this *Connection) close() {
    if kConnStatus_Connected != this.status {
        return
    }

    this.conn.Close()
    this.status = kConnStatus_Disconnected
}

func (this *Connection) Close() {
    if this.status != kConnStatus_Connected {
        return
    }

    select {
    case this.sendMsgQueue <- nil:
        {
            //  nothing
        }
    case <-time.After(time.Duration(this.sendTimeoutSec)):
        {
            //  timeout, close the connection
            this.close()
            log.Printf("Con[%d] send message timeout, close it", this.connId)
        }
    }
}

func (this *Connection) pushEvent(et int, d []byte) {
    if nil == this.eventQueue {
        log.Println("Nil event queue")
        return
    }
    this.eventQueue.Push(newConnEvent(et, this, d))
}

func (this *Connection) GetStatus() int {
    return this.status
}

func (this *Connection) setStatus(stat int) {
    this.status = stat
}

func (this *Connection) GetConnId() int {
    return this.connId
}

func (this *Connection) SetConnId(id int) {
    this.connId = id
}

func (this *Connection) GetUserdata() interface{} {
    return this.userdata
}

func (this *Connection) SetUserdata(ud interface{}) {
    this.userdata = ud
}

func (this *Connection) SetReadTimeoutSec(sec int) {
    this.readTimeoutSec = sec
}

func (this *Connection) GetReadTimeoutSec() int {
    return this.readTimeoutSec
}

func (this *Connection) setStreamProtocol(sp IStreamProtocol) {
    this.streamProtocol = sp
}

func (this *Connection) sendRaw(msg []byte) {
    if this.status != kConnStatus_Connected {
        return
    }

    select {
    case this.sendMsgQueue <- msg:
        {
            //  nothing
        }
    case <-time.After(time.Duration(this.sendTimeoutSec)):
        {
            //  timeout, close the connection
            this.close()
            log.Printf("Con[%d] send message timeout, close it", this.connId)
        }
    }
}

func (this *Connection) Send(msg []byte, cpy bool) {
    if this.status != kConnStatus_Connected {
        return
    }

    buf := msg
    if cpy {
        msgCopy := make([]byte, len(msg))
        copy(msgCopy, msg)
        buf = msgCopy
    }

    select {
    case this.sendMsgQueue <- buf:
        {
            //  nothing
        }
    case <-time.After(time.Duration(this.sendTimeoutSec)):
        {
            //  timeout, close the connection
            this.close()
            log.Printf("Con[%d] send message timeout, close it", this.connId)
        }
    }
}

//  run a routine to process the connection
func (this *Connection) run() {
    go this.routineMain()
}

func (this *Connection) routineMain() {
    defer func() {
        //  routine end
        log.Printf("Routine of connection[%d] quit", this.connId)
        e := recover()
        if e != nil {
            log.Println("Panic:", e)
        }

        //  close the connection
        this.close()

        //  free channel
        close(this.sendMsgQueue)
        this.sendMsgQueue = nil

        //  post event
        this.pushEvent(kConnEvent_Disconnected, nil)
    }()

    if nil == this.streamProtocol {
        log.Println("Nil stream protocol")
        return
    }
    this.streamProtocol.Init()

    //  connected
    this.pushEvent(kConnEvent_Connected, nil)
    this.status = kConnStatus_Connected

    go this.routineSend()
    this.routineRead()
}

func (this *Connection) routineSend() error {
    defer func() {
        log.Println("Connection", this.connId, " send loop return")
    }()

    for {
        select {
        case evt, ok := <-this.sendMsgQueue:
            {
                if !ok {
                    //  channel closed, quit
                    return nil
                }

                if nil == evt {
                    log.Println("User disconnect")
                    this.close()
                    return nil
                }

                var err error

                headerBytes := this.streamProtocol.SerializeHeader(evt)
                if nil != headerBytes {
                    //  write header first
                    _, err = this.conn.Write(headerBytes)
                    if err != nil {
                        log.Println("Conn write error:", err)
                        return err
                    }
                }

                _, err = this.conn.Write(evt)
                if err != nil {
                    log.Println("Conn write error:", err)
                    return err
                }
            }
        }
    }

    return nil
}

func (this *Connection) routineRead() error {
    //  default buffer
    buf := make([]byte, this.maxReadBufferLength)

    for {
        msg, err := this.unpack(buf)
        if err != nil {
            log.Println("Conn read error:", err)
            return err
        }

        this.pushEvent(kConnEvent_Data, msg)
    }

    return nil
}

func (this *Connection) unpack(buf []byte) ([]byte, error) {
    //  read head
    if 0 != this.readTimeoutSec {
        this.conn.SetReadDeadline(time.Now().Add(time.Duration(this.readTimeoutSec) * time.Second))
    }
    headBuf := buf[:this.streamProtocol.GetHeaderLength()]
    _, err := this.conn.Read(headBuf)
    if err != nil {
        return nil, err
    }

    //  check length
    packetLength := this.streamProtocol.UnserializeHeader(headBuf)
    if packetLength > this.maxReadBufferLength ||
        0 == packetLength {
        return nil, errors.New("The stream data is too long")
    }

    //  read body
    if 0 != this.readTimeoutSec {
        this.conn.SetReadDeadline(time.Now().Add(time.Duration(this.readTimeoutSec) * time.Second))
    }
    bodyLength := packetLength - this.streamProtocol.GetHeaderLength()
    _, err = this.conn.Read(buf[:bodyLength])
    if err != nil {
        return nil, err
    }

    //  ok
    msg := make([]byte, bodyLength)
    copy(msg, buf[:bodyLength])
    if 0 != this.readTimeoutSec {
        this.conn.SetReadDeadline(time.Time{})
    }

    return msg, nil
}

這就是簡單的Conn封裝,在拋出事件這部,定義了一個interface來接收事件。

package tcpnetwork

type IEventQueue interface {
    Push(*ConnEvent)
    Pop() *ConnEvent
}

type IStreamProtocol interface {
    //  Init
    Init()
    //  get the header length of the stream
    GetHeaderLength() int
    //  read the header length of the stream
    UnserializeHeader([]byte) int
    //  format header
    SerializeHeader([]byte) []byte
}

type IEventHandler interface {
    OnConnected(evt *ConnEvent)
    OnDisconnected(evt *ConnEvent)
    OnRecv(evt *ConnEvent)
}

我們只要在實現對應的方法,就可以接收事件和讀取事件了。

Server/Client 端的實現

Server

我們已經封裝好了Conn,那麼接下來的工作將會簡單很多。我們來封裝一個TCPNetwork的結構。

對於Server端來說,它基本的步驟就是

  • 監聽端口
  • 開啓accept線程
  • 得到Conn,並生成Connection,開啓Read/Write線程

第一步很簡單,net包已封裝

ls, err := net.Listen("tcp", addr)
if nil != err {
    return err
}

//  accept
this.listener = ls
go this.acceptRoutine()
return nil

我們在listen成功後,開啓了一個goroutine來不斷的進行死循環來等待連接的接入。

func (this *TCPNetwork) acceptRoutine() {
    for {
        conn, err := this.listener.Accept()
        if err != nil {
            log.Println("accept routine quit.error:", err)
            return
        }

        //  process conn event
        connection := this.createConn(conn)
        connection.SetReadTimeoutSec(this.readTimeoutSec)
        connection.from = 0
        connection.run()
    }
}

得到了Connection後,我們只需要讓它的處理routine跑起來即可。然而我們需要對應的Connection拋過來的事件,於是我們在TCPNetwork中實現IEventQueue的2個方法,這樣我們的TCPNetwork就可以接收對應的事件的拋入,也可以讀取,大家讀到這裏也就知道了,最適合實現這個的就是golang的神器:channel。

我們來爲TCPNetwork定義一個chan *ConnEvent,並實現接口的方法。

func (this *TCPNetwork) Push(evt *ConnEvent) {
    if nil == this.eventQueue {
        return
    }
    this.eventQueue <- evt
}

func (this *TCPNetwork) Pop() *ConnEvent {
    evt, ok := <-this.eventQueue
    if !ok {
        //  event queue already closed
        this.eventQueue = nil
        return nil
    }

    return evt
}

其實基本的邏輯已經完成了,對於我們來說,只需要關心eventQueue裏的內容就行了,這就屬於上層邏輯處理了。這樣一個簡單的TCPNetwork就封裝好了。

Client

Client端基本沒什麼好說的,net包Connect成功後會獲得一個Conn,然後對於這個Conn的處理其實和Server端一樣了。

func (this *TCPNetwork) Connect(addr string) error {
    conn, err := net.Dial("tcp", addr)
    if nil != err {
        return err
    }

    connection := this.createConn(conn)
    connection.from = 1
    connection.run()

    return nil
}

使用方法

對於外部來說,使用很簡單,這裏貼上一個簡單的echo server example.

package main

import (
    "log"

    "github.com/sryanyuan/tcpnetwork"
)

type TSShutdown struct {
    network *tcpnetwork.TCPNetwork
}

func NewTSShutdown() *TSShutdown {
    t := &TSShutdown{}
    t.network = tcpnetwork.NewTCPNetwork(1024, tcpnetwork.NewStreamProtocol4())
    return t
}

func (this *TSShutdown) OnConnected(evt *tcpnetwork.ConnEvent) {
    log.Println("connected ", evt.Conn.GetConnId())
}

func (this *TSShutdown) OnDisconnected(evt *tcpnetwork.ConnEvent) {
    log.Println("disconnected ", evt.Conn.GetConnId())
}

func (this *TSShutdown) OnRecv(evt *tcpnetwork.ConnEvent) {
    log.Println("recv ", evt.Conn.GetConnId(), evt.Data)

    evt.Conn.Send(evt.Data, false)
}

func (this *TSShutdown) Run() {
    err := this.network.Listen("127.0.0.1:2222")

    if err != nil {
        log.Println(err)
        return
    }

    this.network.ServeWithHandler(this)
    log.Println("done")
}


package main

import (
    "fmt"
    "log"
)

func main() {
    defer func() {
        e := recover()
        if e != nil {
            log.Println(e)
        }
        var inp int
        fmt.Scanln(&inp)
    }()
    tsshutdown := NewTSShutdown()
    tsshutdown.Run()
}

總結

我們其實已經實現了一個簡單的tcp封裝了,支持server/client連接,至於其中心跳的細節、封包解包的細節等等,這裏就不多介紹了,可以通過閱讀源碼來理解。

該封裝可以見我的Github TCPNetwork

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章