對Conn封裝的基本思路
go內置了net包已經很好的封裝了socket通訊。然而在實際使用中,由於net/Conn的Read/Write方法是堵塞的原因,必須將其放入單獨的goroutine中進行處理。
我們先簡單的整理下思路,對於連接(Conn)的處理,我們可以開啓2條goroutine進行處理,一條用於堵塞的Read的處理,另一條進行Write的處理。
這裏必須指出,其實Write本身就是線程安全的,也就是我們在別的任何地方都可以進行Write,但是Write是堵塞的,所以考慮到這點,我們還是將其放入一個單獨的goroutine中進行處理。
這樣設計的原因在於Conn是支持同時Read/Write的,這樣我們的基本的Conn的模型就成型了。對於服務端或者客戶端而言,我們只需要封裝對應過來的Conn即可,Conn的讀寫goroutine進行處理,並將獲得的事件拋向外部。
那麼我們就按這個思路來實現一個簡單的Connection封裝,該封裝支持線程安全的寫,並且支持解包操作。
package tcpnetwork
import (
"errors"
"log"
"net"
"time"
)
const (
kConnStatus_None = iota
kConnStatus_Connected
kConnStatus_Disconnected
)
const (
kConnEvent_None = iota
kConnEvent_Connected
kConnEvent_Disconnected
kConnEvent_Data
kConnEvent_Close
)
const (
kConnConf_DefaultSendTimeoutSec = 5
kConnConf_MaxReadBufferLength = 0xffff // 0xffff
)
type Connection struct {
conn net.Conn
status int
connId int
sendMsgQueue chan []byte
sendTimeoutSec int
eventQueue IEventQueue
streamProtocol IStreamProtocol
maxReadBufferLength int
userdata interface{}
from int
readTimeoutSec int
}
func newConnection(c net.Conn, sendBufferSize int, eq IEventQueue) *Connection {
return &Connection{
conn: c,
status: kConnStatus_None,
connId: 0,
sendMsgQueue: make(chan []byte, sendBufferSize),
sendTimeoutSec: kConnConf_DefaultSendTimeoutSec,
maxReadBufferLength: kConnConf_MaxReadBufferLength,
eventQueue: eq,
}
}
type ConnEvent struct {
EventType int
Conn *Connection
Data []byte
}
func newConnEvent(et int, c *Connection, d []byte) *ConnEvent {
return &ConnEvent{
EventType: et,
Conn: c,
Data: d,
}
}
// directly close, packages in queue will not be sent
func (this *Connection) close() {
if kConnStatus_Connected != this.status {
return
}
this.conn.Close()
this.status = kConnStatus_Disconnected
}
func (this *Connection) Close() {
if this.status != kConnStatus_Connected {
return
}
select {
case this.sendMsgQueue <- nil:
{
// nothing
}
case <-time.After(time.Duration(this.sendTimeoutSec)):
{
// timeout, close the connection
this.close()
log.Printf("Con[%d] send message timeout, close it", this.connId)
}
}
}
func (this *Connection) pushEvent(et int, d []byte) {
if nil == this.eventQueue {
log.Println("Nil event queue")
return
}
this.eventQueue.Push(newConnEvent(et, this, d))
}
func (this *Connection) GetStatus() int {
return this.status
}
func (this *Connection) setStatus(stat int) {
this.status = stat
}
func (this *Connection) GetConnId() int {
return this.connId
}
func (this *Connection) SetConnId(id int) {
this.connId = id
}
func (this *Connection) GetUserdata() interface{} {
return this.userdata
}
func (this *Connection) SetUserdata(ud interface{}) {
this.userdata = ud
}
func (this *Connection) SetReadTimeoutSec(sec int) {
this.readTimeoutSec = sec
}
func (this *Connection) GetReadTimeoutSec() int {
return this.readTimeoutSec
}
func (this *Connection) setStreamProtocol(sp IStreamProtocol) {
this.streamProtocol = sp
}
func (this *Connection) sendRaw(msg []byte) {
if this.status != kConnStatus_Connected {
return
}
select {
case this.sendMsgQueue <- msg:
{
// nothing
}
case <-time.After(time.Duration(this.sendTimeoutSec)):
{
// timeout, close the connection
this.close()
log.Printf("Con[%d] send message timeout, close it", this.connId)
}
}
}
func (this *Connection) Send(msg []byte, cpy bool) {
if this.status != kConnStatus_Connected {
return
}
buf := msg
if cpy {
msgCopy := make([]byte, len(msg))
copy(msgCopy, msg)
buf = msgCopy
}
select {
case this.sendMsgQueue <- buf:
{
// nothing
}
case <-time.After(time.Duration(this.sendTimeoutSec)):
{
// timeout, close the connection
this.close()
log.Printf("Con[%d] send message timeout, close it", this.connId)
}
}
}
// run a routine to process the connection
func (this *Connection) run() {
go this.routineMain()
}
func (this *Connection) routineMain() {
defer func() {
// routine end
log.Printf("Routine of connection[%d] quit", this.connId)
e := recover()
if e != nil {
log.Println("Panic:", e)
}
// close the connection
this.close()
// free channel
close(this.sendMsgQueue)
this.sendMsgQueue = nil
// post event
this.pushEvent(kConnEvent_Disconnected, nil)
}()
if nil == this.streamProtocol {
log.Println("Nil stream protocol")
return
}
this.streamProtocol.Init()
// connected
this.pushEvent(kConnEvent_Connected, nil)
this.status = kConnStatus_Connected
go this.routineSend()
this.routineRead()
}
func (this *Connection) routineSend() error {
defer func() {
log.Println("Connection", this.connId, " send loop return")
}()
for {
select {
case evt, ok := <-this.sendMsgQueue:
{
if !ok {
// channel closed, quit
return nil
}
if nil == evt {
log.Println("User disconnect")
this.close()
return nil
}
var err error
headerBytes := this.streamProtocol.SerializeHeader(evt)
if nil != headerBytes {
// write header first
_, err = this.conn.Write(headerBytes)
if err != nil {
log.Println("Conn write error:", err)
return err
}
}
_, err = this.conn.Write(evt)
if err != nil {
log.Println("Conn write error:", err)
return err
}
}
}
}
return nil
}
func (this *Connection) routineRead() error {
// default buffer
buf := make([]byte, this.maxReadBufferLength)
for {
msg, err := this.unpack(buf)
if err != nil {
log.Println("Conn read error:", err)
return err
}
this.pushEvent(kConnEvent_Data, msg)
}
return nil
}
func (this *Connection) unpack(buf []byte) ([]byte, error) {
// read head
if 0 != this.readTimeoutSec {
this.conn.SetReadDeadline(time.Now().Add(time.Duration(this.readTimeoutSec) * time.Second))
}
headBuf := buf[:this.streamProtocol.GetHeaderLength()]
_, err := this.conn.Read(headBuf)
if err != nil {
return nil, err
}
// check length
packetLength := this.streamProtocol.UnserializeHeader(headBuf)
if packetLength > this.maxReadBufferLength ||
0 == packetLength {
return nil, errors.New("The stream data is too long")
}
// read body
if 0 != this.readTimeoutSec {
this.conn.SetReadDeadline(time.Now().Add(time.Duration(this.readTimeoutSec) * time.Second))
}
bodyLength := packetLength - this.streamProtocol.GetHeaderLength()
_, err = this.conn.Read(buf[:bodyLength])
if err != nil {
return nil, err
}
// ok
msg := make([]byte, bodyLength)
copy(msg, buf[:bodyLength])
if 0 != this.readTimeoutSec {
this.conn.SetReadDeadline(time.Time{})
}
return msg, nil
}
這就是簡單的Conn封裝,在拋出事件這部,定義了一個interface來接收事件。
package tcpnetwork
type IEventQueue interface {
Push(*ConnEvent)
Pop() *ConnEvent
}
type IStreamProtocol interface {
// Init
Init()
// get the header length of the stream
GetHeaderLength() int
// read the header length of the stream
UnserializeHeader([]byte) int
// format header
SerializeHeader([]byte) []byte
}
type IEventHandler interface {
OnConnected(evt *ConnEvent)
OnDisconnected(evt *ConnEvent)
OnRecv(evt *ConnEvent)
}
我們只要在實現對應的方法,就可以接收事件和讀取事件了。
Server/Client 端的實現
Server
我們已經封裝好了Conn,那麼接下來的工作將會簡單很多。我們來封裝一個TCPNetwork的結構。
對於Server端來說,它基本的步驟就是
- 監聽端口
- 開啓accept線程
- 得到Conn,並生成Connection,開啓Read/Write線程
第一步很簡單,net包已封裝
ls, err := net.Listen("tcp", addr)
if nil != err {
return err
}
// accept
this.listener = ls
go this.acceptRoutine()
return nil
我們在listen成功後,開啓了一個goroutine來不斷的進行死循環來等待連接的接入。
func (this *TCPNetwork) acceptRoutine() {
for {
conn, err := this.listener.Accept()
if err != nil {
log.Println("accept routine quit.error:", err)
return
}
// process conn event
connection := this.createConn(conn)
connection.SetReadTimeoutSec(this.readTimeoutSec)
connection.from = 0
connection.run()
}
}
得到了Connection後,我們只需要讓它的處理routine跑起來即可。然而我們需要對應的Connection拋過來的事件,於是我們在TCPNetwork中實現IEventQueue的2個方法,這樣我們的TCPNetwork就可以接收對應的事件的拋入,也可以讀取,大家讀到這裏也就知道了,最適合實現這個的就是golang的神器:channel。
我們來爲TCPNetwork定義一個chan *ConnEvent,並實現接口的方法。
func (this *TCPNetwork) Push(evt *ConnEvent) {
if nil == this.eventQueue {
return
}
this.eventQueue <- evt
}
func (this *TCPNetwork) Pop() *ConnEvent {
evt, ok := <-this.eventQueue
if !ok {
// event queue already closed
this.eventQueue = nil
return nil
}
return evt
}
其實基本的邏輯已經完成了,對於我們來說,只需要關心eventQueue裏的內容就行了,這就屬於上層邏輯處理了。這樣一個簡單的TCPNetwork就封裝好了。
Client
Client端基本沒什麼好說的,net包Connect成功後會獲得一個Conn,然後對於這個Conn的處理其實和Server端一樣了。
func (this *TCPNetwork) Connect(addr string) error {
conn, err := net.Dial("tcp", addr)
if nil != err {
return err
}
connection := this.createConn(conn)
connection.from = 1
connection.run()
return nil
}
使用方法
對於外部來說,使用很簡單,這裏貼上一個簡單的echo server example.
package main
import (
"log"
"github.com/sryanyuan/tcpnetwork"
)
type TSShutdown struct {
network *tcpnetwork.TCPNetwork
}
func NewTSShutdown() *TSShutdown {
t := &TSShutdown{}
t.network = tcpnetwork.NewTCPNetwork(1024, tcpnetwork.NewStreamProtocol4())
return t
}
func (this *TSShutdown) OnConnected(evt *tcpnetwork.ConnEvent) {
log.Println("connected ", evt.Conn.GetConnId())
}
func (this *TSShutdown) OnDisconnected(evt *tcpnetwork.ConnEvent) {
log.Println("disconnected ", evt.Conn.GetConnId())
}
func (this *TSShutdown) OnRecv(evt *tcpnetwork.ConnEvent) {
log.Println("recv ", evt.Conn.GetConnId(), evt.Data)
evt.Conn.Send(evt.Data, false)
}
func (this *TSShutdown) Run() {
err := this.network.Listen("127.0.0.1:2222")
if err != nil {
log.Println(err)
return
}
this.network.ServeWithHandler(this)
log.Println("done")
}
package main
import (
"fmt"
"log"
)
func main() {
defer func() {
e := recover()
if e != nil {
log.Println(e)
}
var inp int
fmt.Scanln(&inp)
}()
tsshutdown := NewTSShutdown()
tsshutdown.Run()
}
總結
我們其實已經實現了一個簡單的tcp封裝了,支持server/client連接,至於其中心跳的細節、封包解包的細節等等,這裏就不多介紹了,可以通過閱讀源碼來理解。
該封裝可以見我的Github TCPNetwork