打造先進的內存KV數據庫-5 TCP偵聽

TCP偵聽

作爲支持集羣的數據庫，必定要與多個客戶端交互信息。不可能讓數據庫與所有客戶端共享地址空間（雖然這樣性能好），所以需要使用 TCP 協議交換數據（UDP 協議不可靠，故棄用）。C 語言的 TCP 庫其實還好，但在高併發和並行處理方面不如 Go，而且併發鎖機制比較難寫，所以服務器和客戶端都用 Go 編寫，並通過 cgo 調用 C 編寫的數據庫核心庫（libmonkeyS）。目前版本還沒有身份驗證，之後會加上。

代碼實現

//server.go
package main
// #cgo LDFLAGS: -L ./lib -lmonkeyS
// #include "./lib/core.h"
// #include <stdlib.h>
import "C"
import (
    "unsafe"
    _"fmt"
    "net"
    "strings"
)

func main() {
	// Create the base database "monkey" through the C library. The appended
	// 0 byte turns the Go slice into a NUL-terminated C string.
	dbName := append([]byte("monkey"), 0)
	C.CreateDB((*C.char)(unsafe.Pointer(&dbName[0])))

	const listenAddr = ":1517"
	addr, err := net.ResolveTCPAddr("tcp4", listenAddr)
	if err != nil {
		panic(err)
	}
	listener, err := net.ListenTCP("tcp", addr) // listen for TCP clients
	if err != nil {
		panic(err)
	}

	// Accept loop: every client connection is served concurrently by Handler.
	for {
		client, acceptErr := listener.Accept()
		if acceptErr != nil {
			panic(acceptErr)
		}
		go Handler(client)
	}
}

// Handler serves one client connection until the peer disconnects, a read
// fails, or a malformed frame arrives.
//
// Each request frame is a 4-byte big-endian length, which counts the 4 header
// bytes themselves, followed by the command text. Every decoded command is
// handed to TranslateMessage, which writes the reply. The connection starts on
// the "monkey" database and is always closed when Handler returns.
func Handler(conn net.Conn) {
	defer conn.Close()

	// The appended 0 byte makes the name a NUL-terminated C string. The
	// previous version passed the bare bytes, so C read past the end.
	name := append([]byte("monkey"), 0)
	db := C.SwitchDB((*C.char)(unsafe.Pointer(&name[0])))

	// The frame length comes from the network. Cap it so a single bogus
	// header cannot make us allocate gigabytes.
	// NOTE(review): the 64 MiB limit is a chosen safety bound; raise it if
	// larger values must be stored.
	const maxFrame = 64 << 20

	header := make([]byte, 4)
	for {
		if err := readFull(conn, header); err != nil {
			return // peer closed the connection or the read failed
		}
		var total uint32 // the first 4 bytes hold the frame length
		for _, b := range header {
			total = total<<8 | uint32(b)
		}
		if total < 4 || total > maxFrame {
			return // malformed frame: the stream can no longer be trusted
		}
		payloadLen := total - 4
		// Allocate one extra, always-zero byte so the payload stays
		// NUL-terminated. TranslateMessage may pass its last word straight
		// to C as a char*.
		message := make([]byte, payloadLen+1)
		if err := readFull(conn, message[:payloadLen]); err != nil {
			return
		}
		TranslateMessage(conn, &db, message) // parse and answer the request
	}
}

// readFull reads exactly len(buf) bytes from conn, retrying short reads.
// It returns nil once buf is full; otherwise it returns the error that stopped it.
func readFull(conn net.Conn, buf []byte) error {
	for filled := 0; filled < len(buf); {
		n, err := conn.Read(buf[filled:])
		filled += n
		if err != nil && filled < len(buf) {
			return err
		}
	}
	return nil
}

// TranslateMessage parses one request and writes the framed reply to conn.
//
// message holds space-separated words: a command followed by its arguments.
// Only the bytes before the first NUL are considered; anything after it is
// framing padding. Supported commands:
//
//	set KEY VALUE, get KEY, delete|remove KEY,
//	createdb NAME, switchdb NAME, dropdb NAME, listdb (case-insensitive)
//
// createdb, switchdb and dropdb may replace *db, the connection's current
// database. A command with missing arguments, or an unknown command, gets an
// empty response instead of crashing the server. set/get/delete also get an
// empty response when no database is selected (*db == nil).
//
// The reply is a 4-byte big-endian length (counting the header itself)
// followed by the response bytes.
func TranslateMessage(conn net.Conn, db **C.Database, message []byte) {
	command := string(message)
	if end := strings.IndexByte(command, 0); end >= 0 {
		command = command[:end]
	}
	params := strings.Split(command, " ") // always has at least one element

	// arg returns params[i] as a NUL-terminated C string backed by Go memory.
	// It is valid only for the duration of the C call it is passed to,
	// because cgo forbids C from keeping Go pointers. The previous version
	// did not terminate non-final words, such as the set key.
	arg := func(i int) *C.char {
		b := append([]byte(params[i]), 0)
		return (*C.char)(unsafe.Pointer(&b[0]))
	}
	// cBytes copies the C string at p and keeps its terminating 0, because
	// clients print the reply up to that byte.
	cBytes := func(p *C.char) []byte {
		return append([]byte(C.GoString(p)), 0)
	}

	cmd := params[0]
	if strings.EqualFold(cmd, "listdb") {
		cmd = "listdb" // listdb alone is matched case-insensitively
	}

	var response []byte
	switch cmd {
	case "set":
		if len(params) < 3 || *db == nil {
			break
		}
		value := append([]byte(params[2]), 0)
		r := C.Set(&(*db).tIndex, arg(1), unsafe.Pointer(&value[0]))
		response = cBytes(&r.msg[0])
	case "get":
		if len(params) < 2 || *db == nil {
			break
		}
		r := C.Get(&(*db).tIndex, arg(1))
		if int(r.code) == 0 {
			// NOTE(review): pData is converted through uintptr exactly as
			// before, which compiles whether it is declared void* or an
			// integer. Confirm against core.h.
			response = cBytes((*C.char)(unsafe.Pointer(uintptr(r.pData))))
		}
	case "delete", "remove":
		if len(params) < 2 || *db == nil {
			break
		}
		r := C.Delete(&(*db).tIndex, arg(1))
		response = cBytes(&r.msg[0])
	case "createdb":
		if len(params) < 2 {
			break
		}
		if d := C.CreateDB(arg(1)); d != nil {
			*db = d
			response = []byte("Already exist,switched\n")
		} else {
			response = []byte("Created\n")
		}
	case "switchdb":
		if len(params) < 2 {
			break
		}
		if d := C.SwitchDB(arg(1)); d != nil {
			*db = d
			response = []byte("ok\n")
		} else {
			response = []byte("fail\n")
		}
	case "dropdb":
		if len(params) < 2 {
			break
		}
		*db = C.DropDB(arg(1))
	case "listdb":
		list := C.ListDB()
		if list == nil {
			break
		}
		// Copy at most 1024 bytes, stopping after (and including) the first NUL.
		for i := 0; i < 1024; i++ {
			b := byte(*(*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(list)) + uintptr(i))))
			response = append(response, b)
			if b == 0 {
				break
			}
		}
		C.free(unsafe.Pointer(list))
	}

	total := uint32(len(response) + 4)
	frame := append([]byte{byte(total >> 24), byte(total >> 16), byte(total >> 8), byte(total)}, response...)
	if _, err := conn.Write(frame); err != nil {
		// The reply could not be delivered. Closing the connection lets the
		// connection's read loop notice the failure and stop.
		conn.Close()
	}
}
//Client.go
package main

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
)

// main runs the interactive "monkey>" shell.
//
// It reads one command per line from stdin and sends it to the server at
// 127.0.0.1:1517 as one frame: a 4-byte big-endian length (counting the
// header) followed by the request text. It then prints the reply up to its
// first NUL byte. It exits on "exit", at end of input, or when the connection
// fails. Extra words beyond a command's arguments are ignored.
func main() {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:1517")
	if err != nil {
		panic(err)
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Number of arguments each supported command takes.
	argCount := map[string]int{
		"set": 2, "get": 1, "remove": 1, "delete": 1,
		"createdb": 1, "switchdb": 1, "dropdb": 1, "listdb": 0,
	}

	// Read whole lines: fmt.Scanf("%s") leaves the newline behind and
	// never reports EOF usefully, which produced empty requests and an
	// endless loop at end of input.
	input := bufio.NewScanner(os.Stdin)
	for {
		fmt.Print("monkey>")
		if !input.Scan() {
			fmt.Println()
			return // end of input, or a read error on stdin
		}
		fields := strings.Fields(input.Text())
		if len(fields) == 0 {
			continue
		}
		cmd := fields[0]
		if cmd == "exit" {
			fmt.Println("Bye!")
			return
		}
		n, known := argCount[cmd]
		if !known {
			fmt.Println("unknown command:", cmd)
			continue
		}
		if len(fields) < n+1 {
			fmt.Printf("%s needs %d argument(s)\n", cmd, n)
			continue
		}
		request := strings.Join(fields[:n+1], " ")
		if n == 0 {
			// Keep the trailing space that was always sent after listdb.
			// The server splits on spaces and matches the first word exactly.
			request += " "
		}

		frame := make([]byte, 4, 4+len(request))
		binary.BigEndian.PutUint32(frame, uint32(len(request)+4))
		frame = append(frame, request...)
		if _, err := conn.Write(frame); err != nil {
			fmt.Println("send failed:", err)
			return
		}

		reply, err := readReply(conn)
		if err != nil {
			fmt.Println("receive failed:", err)
			return
		}
		if end := bytes.IndexByte(reply, 0); end >= 0 {
			reply = reply[:end]
		}
		fmt.Println(string(reply)) // print as text so UTF-8 values stay intact
	}
}

// readReply reads one frame from r and returns the payload after the header.
// The frame is a 4-byte big-endian length, which counts the header itself.
// TCP is a byte stream, so io.ReadFull is used to collect exactly the
// advertised number of bytes regardless of how they were split in transit.
func readReply(r io.Reader) ([]byte, error) {
	var header [4]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, err
	}
	total := binary.BigEndian.Uint32(header[:])
	if total < 4 {
		return nil, fmt.Errorf("invalid frame length %d", total)
	}
	payload := make([]byte, total-4)
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, err
	}
	return payload, nil
}

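修正：上述代碼存在嚴重問題：
發送 1K 以上的數據時無法正確接收。原因是 TCP 是字節流，一次 Read 讀到的數據與對端一次 Write 寫出的數據並不一一對應；而上面的代碼每次固定讀入 1024 字節的緩衝區，並把整個緩衝區（而不是實際讀到的 length 個字節）追加到消息中，消息邊界因此錯亂。
改進代碼如下：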

//tcp.go
package tcp
import "net"
import "fmt"

// ok reports whether bytes starts with the NUL-terminated status string
// "ok" ('o', 'k', 0). Inputs shorter than three bytes are simply not "ok".
// Previously they caused an index-out-of-range panic.
func ok(bytes []byte) bool {
	return len(bytes) >= 3 && bytes[0] == 'o' && bytes[1] == 'k' && bytes[2] == 0
}

// bytes4uint decodes the first four bytes of bytes as a big-endian uint32.
// Any bytes after the fourth are ignored. It panics if fewer than four
// bytes are given.
func bytes4uint(bytes []byte) uint32 {
	var n uint32
	for _, b := range bytes[:4] {
		n = n<<8 | uint32(b)
	}
	return n
}

// uint32bytes encodes n as four big-endian bytes (most significant first).
// It is the inverse of bytes4uint.
func uint32bytes(n uint32) []byte {
	return []byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
}


// TCPSession wraps one TCP connection with two background goroutines. Send
// writes every []byte queued on ToSend to Conn, and Recv forwards the bytes
// read from Conn to Received.
//
// SendMessage and ReadMessage add message framing on top. Each message
// travels as a 4-byte big-endian length, which counts the 4 header bytes
// (the same convention server.go and Client.go use), followed by the payload.
// The old header counted 1024-byte chunks. That lost the real payload length
// and relied on every Read returning exactly one Write, which TCP does not
// guarantee. Peers must therefore run this same version.
//
// Call Init once before use. ReadMessage must not be called from more than
// one goroutine at a time.
type TCPSession struct {
	Conn     *net.TCPConn
	ToSend   chan interface{} // []byte values waiting to be written to Conn
	Received chan interface{} // []byte values read from Conn, in arrival order
	Closed   bool             // set once Send or Recv stops because of an I/O error
	// NOTE(review): Closed is written and read by different goroutines
	// without synchronization. Making it race-free needs sync/atomic.

	pending []byte        // received bytes not yet consumed by ReadMessage
	done    chan struct{} // closed by Recv when the connection is finished
}

// Init allocates the channels and starts the Send and Recv goroutines.
func (s *TCPSession) Init() {
	s.ToSend = make(chan interface{})
	s.Received = make(chan interface{})
	s.done = make(chan struct{})
	go s.Send()
	go s.Recv()
}

// Send writes each []byte taken from ToSend to Conn. It stops when a write
// fails or the connection is finished. On a write error it also closes Conn,
// so that Recv stops and signals the other waiters.
func (s *TCPSession) Send() {
	for !s.Closed {
		var item interface{}
		select {
		case item = <-s.ToSend:
		case <-s.done:
			return
		}
		if _, err := s.Conn.Write(item.([]byte)); err != nil {
			s.Closed = true
			s.Conn.Close()
			return
		}
	}
}

// Recv reads from Conn and forwards exactly the bytes each Read returned
// (buf[:n]) to Received. The old code sent the whole zero-padded 1024-byte
// buffer. When a read fails, Recv marks the session closed and closes done,
// which wakes any SendMessage or ReadMessage blocked on the session.
func (s *TCPSession) Recv() {
	defer func() {
		s.Closed = true
		if s.done != nil {
			close(s.done)
		}
	}()
	for !s.Closed {
		buf := make([]byte, 1024)
		n, err := s.Conn.Read(buf)
		if n > 0 {
			// Hand data over before reporting the error, per the io.Reader
			// contract. The channel is unbuffered, so this completes before
			// done can be closed and no data is lost.
			s.Received <- buf[:n]
		}
		if err != nil {
			return
		}
	}
}

// SendMessage queues bytes as one framed message: uint32bytes(len(bytes)+4)
// followed by the payload. Header and payload are queued as a single slice,
// so concurrent SendMessage calls cannot interleave on the wire. If the
// connection has already finished, the message is dropped rather than
// blocking forever.
func (s *TCPSession) SendMessage(bytes []byte) {
	frame := make([]byte, 0, 4+len(bytes))
	frame = append(frame, uint32bytes(uint32(len(bytes)+4))...)
	frame = append(frame, bytes...)
	select {
	case s.ToSend <- frame:
	case <-s.done:
	}
}

// ReadMessage blocks until one complete framed message has arrived and
// returns its payload, without the 4-byte header. A message may arrive split
// across several reads or merged with the next one; leftover bytes are kept
// for the following call. An empty payload is returned as a non-nil empty
// slice.
//
// It returns nil in two cases: the connection finishes before a full message
// arrives, or the header is malformed (a length smaller than the header
// itself).
//
// NOTE(review): the length comes from the peer and is not bounded here.
func (s *TCPSession) ReadMessage() []byte {
	header := s.readN(4)
	if header == nil {
		return nil
	}
	total := bytes4uint(header)
	if total < 4 {
		return nil
	}
	return s.readN(int(total - 4))
}

// readN removes and returns exactly n bytes from the received stream,
// pulling more chunks from Received as needed. It returns nil if the
// connection finishes before n bytes are available.
func (s *TCPSession) readN(n int) []byte {
	for len(s.pending) < n {
		select {
		case item := <-s.Received:
			s.pending = append(s.pending, item.([]byte)...)
		case <-s.done:
			return nil
		}
	}
	out := make([]byte, n)
	copy(out, s.pending)
	s.pending = s.pending[n:]
	return out
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章