TCP偵聽
作爲支持集羣的數據庫，必定要與多個客戶端交換信息。不可能讓數據庫與所有客戶端共享地址空間（雖然這樣性能更好），所以需要使用 TCP 協議交換數據（UDP 協議不可靠，故棄用）。C 語言的 TCP 庫其實還可以，但在高併發和並行處理方面不如 Go，而且併發鎖機制比較難寫，所以用 Go 編寫服務器和客戶端，並由它們調用 C 的庫。目前版本還沒有身份驗證，之後會加上。
代碼實現
//server.go
package main
// #cgo LDFLAGS: -L ./lib -lmonkeyS
// #include "./lib/core.h"
// #include <stdlib.h>
import "C"
import (
	"encoding/binary"
	_ "fmt"
	"io"
	"log"
	"net"
	"strings"
	"time"
	"unsafe"
)
// main creates the default "monkey" database, listens on TCP port 1517 and
// serves each accepted connection on its own goroutine.
func main() {
	name := []byte("monkey")
	name = append(name, 0)                          // C expects a NUL-terminated name
	C.CreateDB((*C.char)(unsafe.Pointer(&name[0]))) // create the base database

	tcpAddr, err := net.ResolveTCPAddr("tcp4", ":1517")
	if err != nil {
		panic(err)
	}
	l, err := net.ListenTCP("tcp", tcpAddr) // listen for TCP clients
	if err != nil {
		panic(err)
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			// Panicking here would take down every connected client because
			// of one failed Accept (e.g. running out of file descriptors).
			// Log it, pause briefly so a persistent error does not spin the
			// CPU, and keep serving.
			log.Printf("accept: %v", err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		go Handler(conn)
	}
}
// Handler serves one client connection until the client disconnects or a
// read fails, and then closes the connection.
//
// Each request is framed as a 4-byte big-endian length, which counts the
// header itself, followed by the payload. Frames are read with io.ReadFull
// because a single TCP read may return part of a frame or more than one
// frame. Every session starts in the "monkey" database; TranslateMessage may
// switch it through the db pointer.
func Handler(conn net.Conn) {
	defer conn.Close()

	// The name handed to C must be NUL-terminated; without the terminator,
	// C string functions could read past the end of the Go slice.
	name := []byte("monkey\x00")
	db := C.SwitchDB((*C.char)(unsafe.Pointer(&name[0])))

	header := make([]byte, 4)
	for {
		if _, err := io.ReadFull(conn, header); err != nil {
			return
		}
		total := binary.BigEndian.Uint32(header)
		if total < 4 {
			return // malformed frame: the length must include the header
		}
		// NOTE(review): total comes from the network and is not capped, so a
		// peer can force a large allocation; consider a maximum frame size.
		//
		// The extra trailing byte stays 0, keeping the payload NUL-terminated
		// so its last word can be handed to C as a string.
		payload := make([]byte, total-4+1)
		if _, err := io.ReadFull(conn, payload[:total-4]); err != nil {
			return
		}
		TranslateMessage(conn, &db, payload)
	}
}
// TranslateMessage executes one client request and writes the framed reply.
//
// message is the request payload: a command word followed by space-separated
// arguments. Trailing NUL bytes (padding left by the reader) are ignored.
// Commands: set, get, delete/remove, createdb, switchdb, dropdb, and listdb
// (the last matched case-insensitively). *db is the session's current
// database; createdb, switchdb and dropdb may replace it. A command whose
// required argument is missing or empty gets "missing argument\n" instead of
// indexing past params, which used to panic and crash the whole server.
//
// The reply written to conn is a 4-byte big-endian length, which counts the
// header itself, followed by the response bytes. Text coming from C is
// forwarded together with its NUL terminator, which the client treats as end
// of text. Unknown commands and unsuccessful gets produce an empty response.
func TranslateMessage(conn net.Conn, db **C.Database, message []byte) {
	params := strings.Split(strings.TrimRight(string(message), "\x00"), " ")

	// arg returns argument i as a NUL-terminated C string, or nil when it is
	// absent or empty. The bytes live in Go memory, so (per the cgo
	// pointer-passing rules) C must not keep the pointer after the call.
	arg := func(i int) *C.char {
		if i >= len(params) || params[i] == "" {
			return nil
		}
		b := append([]byte(params[i]), 0)
		return (*C.char)(unsafe.Pointer(&b[0]))
	}
	// fromArray copies a fixed-size C char array up to and including its
	// first NUL, never reading past the end of the array.
	fromArray := func(chars []C.char) []byte {
		out := make([]byte, 0, len(chars))
		for _, c := range chars {
			out = append(out, byte(c))
			if c == 0 {
				break
			}
		}
		return out
	}
	// fromPointer copies a NUL-terminated C string and keeps the terminator.
	fromPointer := func(p *C.char) []byte {
		return append([]byte(C.GoString(p)), 0)
	}
	missing := []byte("missing argument\n")

	var response []byte
	switch cmd := params[0]; {
	case cmd == "set":
		key, value := arg(1), arg(2)
		if key == nil || value == nil {
			response = missing
		} else {
			r := C.Set(&(*db).tIndex, key, unsafe.Pointer(value))
			response = fromArray(r.msg[:])
		}
	case cmd == "get":
		if key := arg(1); key == nil {
			response = missing
		} else if r := C.Get(&(*db).tIndex, key); int(r.code) == 0 {
			response = fromPointer((*C.char)(unsafe.Pointer(r.pData)))
		}
		// On a non-zero code the reply stays empty (r.msg is not forwarded).
	case cmd == "delete" || cmd == "remove":
		if key := arg(1); key == nil {
			response = missing
		} else {
			r := C.Delete(&(*db).tIndex, key)
			response = fromArray(r.msg[:])
		}
	case cmd == "createdb":
		// CreateDB returns non-nil when the database already exists.
		if name := arg(1); name == nil {
			response = missing
		} else if d := C.CreateDB(name); d != nil {
			*db = d
			response = []byte("Already exist,switched\n")
		} else {
			response = []byte("Created\n")
		}
	case cmd == "switchdb":
		if name := arg(1); name == nil {
			response = missing
		} else if d := C.SwitchDB(name); d != nil {
			*db = d
			response = []byte("ok\n")
		} else {
			response = []byte("fail\n")
		}
	case cmd == "dropdb":
		if name := arg(1); name == nil {
			response = missing
		} else {
			*db = C.DropDB(name)
		}
	case strings.EqualFold("listdb", cmd):
		r := C.ListDB()
		response = fromPointer((*C.char)(unsafe.Pointer(r)))
		C.free(unsafe.Pointer(r))
	}

	frame := make([]byte, 4, 4+len(response))
	binary.BigEndian.PutUint32(frame, uint32(len(response)+4))
	frame = append(frame, response...)
	// The write error is deliberately dropped: a failed write means the
	// connection is broken, and the caller's next read on conn is expected
	// to fail as well and end the session.
	_, _ = conn.Write(frame)
}
//Client.go
package main
import "net"
import "fmt"
// main is an interactive command-line client for the server on
// 127.0.0.1:1517.
//
// It reads a command word (and its arguments) from stdin and sends it as a
// frame: a 4-byte big-endian length that counts the header, then the
// payload. It then prints the reply text up to its NUL terminator. "exit"
// quits.
func main() {
	tcpAddr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:1517")
	if err != nil {
		panic(err)
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// readFull reads exactly len(p) bytes. A single Read may return fewer
	// bytes than a frame holds, so it must be repeated until p is full.
	readFull := func(p []byte) error {
		for n := 0; n < len(p); {
			m, err := conn.Read(p[n:])
			n += m
			if err != nil && n < len(p) {
				return err
			}
		}
		return nil
	}

	for {
		var cmd, arg1, arg2, request string
		fmt.Print("monkey>")
		// NOTE(review): Scanf errors are ignored, so at end of stdin this loop
		// keeps sending empty requests; consider bufio.Scanner on os.Stdin.
		fmt.Scanf("%s", &cmd)
		switch cmd {
		case "set":
			fmt.Scanf("%s", &arg1)
			fmt.Scanf("%s", &arg2)
			request = cmd + " " + arg1 + " " + arg2
		case "get", "remove", "delete", "createdb", "switchdb", "dropdb":
			fmt.Scanf("%s", &arg1)
			request = cmd + " " + arg1
		case "listdb":
			request = cmd + " "
		case "exit":
			fmt.Println("Bye!")
			return
		}

		size := uint32(len(request) + 4)
		frame := []byte{byte(size >> 24), byte(size >> 16), byte(size >> 8), byte(size)}
		if _, err := conn.Write(append(frame, request...)); err != nil {
			fmt.Println("send:", err)
			return
		}

		header := make([]byte, 4)
		if err := readFull(header); err != nil {
			fmt.Println("receive:", err)
			return
		}
		size = uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
		if size < 4 {
			fmt.Println("receive: malformed reply length")
			return
		}
		reply := make([]byte, size-4)
		if err := readFull(reply); err != nil {
			fmt.Println("receive:", err)
			return
		}

		// Print the raw bytes up to the NUL terminator (if any) in one go, so
		// UTF-8 text is not re-encoded byte by byte as %c would do.
		text := reply
		for i, b := range reply {
			if b == 0 {
				text = reply[:i]
				break
			}
		}
		fmt.Println(string(text))
	}
}
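修正：上述代碼存在嚴重問題：
發送超過 1K（1024 字節）的數據時無法被正確接收。原因是 TCP 是字節流協議：代碼每次只用 1024 字節的緩衝區調用一次 `Read`，而每次讀到的長度與對方每次寫入的長度並不對應，必須按消息頭給出的長度循環讀滿，才能得到完整的一條消息。
改進代碼如下：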
//tcp.go
package tcp
import "net"
import "fmt"
// ok reports whether bytes starts with the NUL-terminated reply "ok"
// (the bytes 'o', 'k', 0). Inputs shorter than three bytes report false
// instead of panicking with an index out of range.
func ok(bytes []byte) bool {
	return len(bytes) >= 3 && bytes[0] == 'o' && bytes[1] == 'k' && bytes[2] == 0
}
// bytes4uint decodes the first four bytes of bytes as a big-endian uint32
// (most significant byte first). bytes must hold at least four bytes;
// anything after the fourth is ignored.
func bytes4uint(bytes []byte) uint32 {
	return uint32(bytes[0])<<24 |
		uint32(bytes[1])<<16 |
		uint32(bytes[2])<<8 |
		uint32(bytes[3])
}
// uint32bytes encodes n as a fresh 4-byte big-endian slice
// (most significant byte first).
func uint32bytes(n uint32) []byte {
	return []byte{
		byte(n >> 24),
		byte(n >> 16),
		byte(n >> 8),
		byte(n),
	}
}
type TCPSession struct {
Conn *net.TCPConn
ToSend chan interface{} //要發送的數據
Received chan interface{} //接受到的數據
Closed bool //是否已經關閉
}
func (s *TCPSession) Init() {
s.ToSend = make(chan interface{})
s.Received = make(chan interface{})
go s.Send()
go s.Recv()
}
func (s *TCPSession) Send() {
for {
if s.Closed {
return
}
buf0 := <- s.ToSend //取出要發送的數據
buf := buf0.([]byte)
_,err := s.Conn.Write(buf) //發送掉
//fmt.Println("send,",buf)
if err != nil {
s.Closed = true
return
}
}
}
func (s *TCPSession) Recv() {
for {
if s.Closed {
return
}
buf := make([]byte,1024)
_,err := s.Conn.Read(buf)
if err != nil {
s.Closed = true
return
}
s.Received <- buf
//fmt.Println("read,",buf)
}
}
func (s *TCPSession) SendMessage(bytes []byte) {
total := len(bytes) / 1024
if len(bytes) % 1024 != 0 {
total++
}
header := uint32bytes(uint32(total)) //計算條數
s.ToSend <- header
//fmt.Println(header)
for i := 0;i < total-1;i++ {
buf := bytes[0:1024] //發送這一段
bytes = bytes[1024:]
s.ToSend <- buf
continue
}
//發送最後一段
if total == 0 {
return
}
buf := bytes[0:] //發送這一段
s.ToSend <- buf
}
func (s *TCPSession) ReadMessage() []byte {
buf0 := <- s.Received
buf := buf0.([]byte)
//fmt.Println(buf)
total := bytes4uint(buf)
var buff []byte
if buf[4] != 0 { //兩份報表被合併
buff = buf[4:]
total--
} else {
buff = []byte{}
}
for i := uint32(0);i < total;i++ {
buf0 := <- s.Received
buf := buf0.([]byte)
buff = append(buff,buf...)
}
return buff
}