用Go寫一個內網穿透工具

系統架構

系統分爲兩個部分,client 和 server,client運行在內網服務器中,server運行在公網服務器中,當我們想訪問內網中的服務,我們通過公網服務器做一箇中繼。

下面是展示我靈魂畫手的時刻了

user發送請求給 server,server和client建立連接,將請求發給client,client再將請求發給本地程序處理(內網中),然後本地程序將處理結果返回給client,client將結果返回給server,server再將結果返回給用戶,這樣用戶就訪問到了內網中的程序了。

代碼流程

  1. server端監聽兩個端口,一個用來和user通信,一個和client通信
  2. client啓動時連接server端,並啓動一個端口監聽本地某程序
  3. 當User連接到server端口,將User請求內容發給client
  4. client將從server收到的請求發給本地程序
  5. client將從本地程序收到的內容發給server
  6. server將從client收到的內容發給User即可

  1. 當Server與client沒有消息通信,連接會斷開
  2. client斷開後,再啓動會連接不到Server
  3. Server端會因爲client斷開而引發panic

爲了解決這種坑點,加入了心跳包機制,通過5s發送一次心跳包,保持client與server的連接,同時建立一個重連通道,監聽該通道,如果當Client被斷開後,則往重連通道放一個值,告訴Server端,等待新的Client連接,而避免引發Panic

代碼

更詳細的我就不說了,直接看代碼,代碼裏面有詳細的註釋

代碼倉庫地址: https://github.com/pibigstar/go-proxy

Server端

運行在具有公網IP地址的服務器端

package main
import (
	"flag"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"time"
)

var (
	localPort  int
	remotePort int
)

func init() {
	flag.IntVar(&localPort, "l", 5200, "the user link port")
	flag.IntVar(&remotePort, "r", 3333, "client listen port")
}

type client struct {
	conn net.Conn
	// 數據傳輸通道
	read  chan []byte
	write chan []byte
	// 異常退出通道
	exit chan error
	// 重連通道
	reConn chan bool
}

// 從Client端讀取數據
func (c *client) Read() {
	// 如果10秒鐘內沒有消息傳輸,則Read函數會返回一個timeout的錯誤
	_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
	for {
		data := make([]byte, 10240)
		n, err := c.conn.Read(data)
		if err != nil && err != io.EOF {
			if strings.Contains(err.Error(), "timeout") {
				// 設置讀取時間爲3秒,3秒後若讀取不到, 則err會拋出timeout,然後發送心跳
				_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
				c.conn.Write([]byte("pi"))
				continue
			}
			fmt.Println("讀取出現錯誤...")
			c.exit <- err
		}

		// 收到心跳包,則跳過
		if data[0] == 'p' && data[1] == 'i' {
			fmt.Println("server收到心跳包")
			continue
		}
		c.read <- data[:n]
	}
}

// 將數據寫入到Client端
func (c *client) Write() {
	for {
		select {
		case data := <-c.write:
			_, err := c.conn.Write(data)
			if err != nil && err != io.EOF {
				c.exit <- err
			}
		}
	}
}

type user struct {
	conn net.Conn
	// 數據傳輸通道
	read  chan []byte
	write chan []byte
	// 異常退出通道
	exit chan error
}

// 從User端讀取數據
func (u *user) Read() {
	_ = u.conn.SetReadDeadline(time.Now().Add(time.Second * 200))
	for {
		data := make([]byte, 10240)
		n, err := u.conn.Read(data)
		if err != nil && err != io.EOF {
			u.exit <- err
		}
		u.read <- data[:n]
	}
}

// 將數據寫給User端
func (u *user) Write() {
	for {
		select {
		case data := <-u.write:
			_, err := u.conn.Write(data)
			if err != nil && err != io.EOF {
				u.exit <- err
			}
		}
	}
}

func main() {
	flag.Parse()

	defer func() {
		err := recover()
		if err != nil {
			fmt.Println(err)
		}
	}()

	clientListener, err := net.Listen("tcp", fmt.Sprintf(":%d", remotePort))
	if err != nil {
		panic(err)
	}
	fmt.Printf("監聽:%d端口, 等待client連接... \n", remotePort)
	// 監聽User來連接
	userListener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
	if err != nil {
		panic(err)
	}
	fmt.Printf("監聽:%d端口, 等待user連接.... \n", localPort)

	for {
		// 有Client來連接了
		clientConn, err := clientListener.Accept()
		if err != nil {
			panic(err)
		}

		fmt.Printf("有Client連接: %s \n", clientConn.RemoteAddr())

		client := &client{
			conn:   clientConn,
			read:   make(chan []byte),
			write:  make(chan []byte),
			exit:   make(chan error),
			reConn: make(chan bool),
		}

		userConnChan := make(chan net.Conn)
		go AcceptUserConn(userListener, userConnChan)

		go HandleClient(client, userConnChan)

		<-client.reConn
		fmt.Println("重新等待新的client連接..")
	}
}

func HandleClient(client *client, userConnChan chan net.Conn) {

	go client.Read()
	go client.Write()

	for {
		select {
		case err := <-client.exit:
			fmt.Printf("client出現錯誤, 開始重試, err: %s \n", err.Error())
			client.reConn <- true
			runtime.Goexit()

		case userConn := <-userConnChan:
			user := &user{
				conn:  userConn,
				read:  make(chan []byte),
				write: make(chan []byte),
				exit:  make(chan error),
			}
			go user.Read()
			go user.Write()

			go handle(client, user)
		}
	}
}

// 將兩個Socket通道鏈接
// 1. 將從user收到的信息發給client
// 2. 將從client收到信息發給user
func handle(client *client, user *user) {
	for {
		select {
		case userRecv := <-user.read:
			// 收到從user發來的信息
			client.write <- userRecv
		case clientRecv := <-client.read:
			// 收到從client發來的信息
			user.write <- clientRecv

		case err := <-client.exit:
			fmt.Println("client出現錯誤, 關閉連接", err.Error())
			_ = client.conn.Close()
			_ = user.conn.Close()
			client.reConn <- true
			// 結束當前goroutine
			runtime.Goexit()

		case err := <-user.exit:
			fmt.Println("user出現錯誤,關閉連接", err.Error())
			_ = user.conn.Close()
		}
	}
}

// 等待user連接
func AcceptUserConn(userListener net.Listener, connChan chan net.Conn) {
	userConn, err := userListener.Accept()
	if err != nil {
		panic(err)
	}
	fmt.Printf("user connect: %s \n", userConn.RemoteAddr())
	connChan <- userConn
}

Client端

運行在需要內網穿透的客戶端中

package main

import (
	"flag"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"time"
)

var (
	host       string
	localPort  int
	remotePort int
)

func init() {
	flag.StringVar(&host, "h", "127.0.0.1", "remote server ip")
	flag.IntVar(&localPort, "l", 8080, "the local port")
	flag.IntVar(&remotePort, "r", 3333, "remote server port")
}

type server struct {
	conn net.Conn
	// 數據傳輸通道
	read  chan []byte
	write chan []byte
	// 異常退出通道
	exit chan error
	// 重連通道
	reConn chan bool
}

// 從Server端讀取數據
func (s *server) Read() {
	// 如果10秒鐘內沒有消息傳輸,則Read函數會返回一個timeout的錯誤
	_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
	for {
		data := make([]byte, 10240)
		n, err := s.conn.Read(data)
		if err != nil && err != io.EOF {
			// 讀取超時,發送一個心跳包過去
			if strings.Contains(err.Error(), "timeout") {
				// 3秒發一次心跳
				_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
				s.conn.Write([]byte("pi"))
				continue
			}
			fmt.Println("從server讀取數據失敗, ", err.Error())
			s.exit <- err
			runtime.Goexit()
		}

		// 如果收到心跳包, 則跳過
		if data[0] == 'p' && data[1] == 'i' {
			fmt.Println("client收到心跳包")
			continue
		}
		s.read <- data[:n]
	}
}

// 將數據寫入到Server端
func (s *server) Write() {
	for {
		select {
		case data := <-s.write:
			_, err := s.conn.Write(data)
			if err != nil && err != io.EOF {
				s.exit <- err
			}
		}
	}
}

type local struct {
	conn net.Conn
	// 數據傳輸通道
	read  chan []byte
	write chan []byte
	// 有異常退出通道
	exit chan error
}

func (l *local) Read() {

	for {
		data := make([]byte, 10240)
		n, err := l.conn.Read(data)
		if err != nil {
			l.exit <- err
		}
		l.read <- data[:n]
	}
}

func (l *local) Write() {
	for {
		select {
		case data := <-l.write:
			_, err := l.conn.Write(data)
			if err != nil {
				l.exit <- err
			}
		}
	}
}

func main() {
	flag.Parse()

	target := net.JoinHostPort(host, fmt.Sprintf("%d", remotePort))
	for {
		serverConn, err := net.Dial("tcp", target)
		if err != nil {
			panic(err)
		}

		fmt.Printf("已連接server: %s \n", serverConn.RemoteAddr())
		server := &server{
			conn:   serverConn,
			read:   make(chan []byte),
			write:  make(chan []byte),
			exit:   make(chan error),
			reConn: make(chan bool),
		}

		go server.Read()
		go server.Write()

		go handle(server)

		<-server.reConn
		_ = server.conn.Close()
	}

}

func handle(server *server) {
	// 等待server端發來的信息,也就是說user來請求server了
	data := <-server.read

	localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
	if err != nil {
		panic(err)
	}

	local := &local{
		conn:  localConn,
		read:  make(chan []byte),
		write: make(chan []byte),
		exit:  make(chan error),
	}

	go local.Read()
	go local.Write()

	local.write <- data

	for {
		select {
		case data := <-server.read:
			local.write <- data

		case data := <-local.read:
			server.write <- data

		case err := <-server.exit:
			fmt.Printf("server have err: %s", err.Error())
			_ = server.conn.Close()
			_ = local.conn.Close()
			server.reConn <- true

		case err := <-local.exit:
			fmt.Printf("server have err: %s", err.Error())
			_ = local.conn.Close()
		}
	}
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章