socket多線程下載文件

多線程下載的思路是建立多個線程,同時連接到服務器,分別從文件的不同位置開始下載,然後將各自接收到的數據合併到同一個文件中。

// 服務器端代碼
/*************************************************************************
    > File Name: server.cpp
    > Author: hp
    > Mail: [email protected]
    > Created Time: 2017年01月28日 星期六 19時39分47秒
 ************************************************************************/

#include <iostream>
#include <cstring>
#include <thread>
#include <unistd.h>
#include <fcntl.h>
#include <dirent.h>
#include <error.h>
#include <sys/types.h>  
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h> 
#include <arpa/inet.h>
#include <signal.h>

#define PORT            8888        // 偵聽端口地址
#define BACKLOG         20          // 偵聽隊列長度
#define MAXLEN          512         // 消息最大長度
const int MAX_THREADS = 2;          // 最大線程數

using namespace std;

struct message {
    int sc;                     // 套接字
    int flag;                   // 判斷下載文件是否存在
    int infd;                   // 源文件句柄
    size_t start;               // 文件的寫入起始位置
    size_t end;                 // 文件的寫入終止位置
    int filesize;               // 下載文件大小
    char filelist[10000];       // 下載文件列表
    char filename[256];         // 下載文件名
    char file[1024];            // 文件內容
} inf;


class Download_data {
public: 
    void get_filelist(void);            // 獲取下載文件列表
    int get_size(const char *);         // 獲取文件大小
    void my_err(const char *, int );    // 錯誤處理函數
};

void handler_sigint(int );          // 信號處理函數
void *process(void *);              // 服務器對客戶端的處理
void *thread_send(void *);          // 線程發送文件

// 錯誤處理函數
void Download_data::my_err(const char *err_string, int line)
{
    fprintf(stderr, "line: %d ", line);
    perror(err_string); 
    exit(1);
}

// 獲取下載文件列表
void Download_data::get_filelist(void)
{
    DIR *dir;
    struct dirent *ptr;
    char path[MAXLEN];

    // 獲取當前工作路徑
    getcwd(path, MAXLEN);

    // 打開當前工作路徑
    if ((dir = opendir(path)) == NULL) {
        perror("opendir");
    }

    // 循環將當前路徑下的所有文件名連接起來
    while ((ptr = readdir(dir)) != NULL) {
        // 如果是'.'或'..'則跳過
        if (strcmp(ptr->d_name, ".") == 0 || strcmp(ptr->d_name, "..") == 0) {
            continue;
        }
        strcat(inf.filelist, ptr->d_name);
        strcat(inf.filelist, "\n");
        //sprintf(filelist, "%s%s", ptr->d_name, "\n");
    }
    // 關閉目錄流
    closedir(dir);
}

// 獲取文件大小
int Download_data::get_size(const char *file_name) 
{
    struct stat st;
    bzero(&st, sizeof(st));
    stat(file_name, &st);
    return st.st_size;
}

// 信號處理函數
void handler_sigint(int signo)
{
}

/* 服務器對客戶端的處理 */
void *process(void *arg)
{
    int fd;
    int sc = *(int *)arg;
    Download_data dd;

    // 獲取當前工作路徑下的所有文件名併發送至客戶端
    dd.get_filelist();
    if ( send(sc, &inf, sizeof(struct message), 0) < 0 ) {
        dd.my_err("send", __LINE__);
    }

    while (1) {
        // 清空結構體裏的數據
        memset(&inf, 0, sizeof(struct message));

        // 接收來自客戶端的待下載文件名
        if ( recv(sc, inf.filename, sizeof(inf.filename), 0) < 0 ) {
            dd.my_err("recv", __LINE__);
        }  

        // 如果待下載文件存在,則將標誌位置爲1,並獲取該文件的大小
        // 否則將標誌位置爲0
        if (access(inf.filename, 0) == 0) {
            inf.flag = 1;
            inf.filesize = dd.get_size(inf.filename);
        } else {
            inf.flag = 0;
        }

        // 將結構體發至客戶端
        if (send(sc, &inf, sizeof(struct message), 0) < 0) {
            dd.my_err("send", __LINE__);
        }

        // 如果待下載文件不存在,則繼續等待客戶端的再次輸入
        if (inf.flag == 0) {
            continue;
        }

        // 打開待下載文件
        if ( (fd = open(inf.filename, O_RDONLY)) < 0 ) {
            dd.my_err("open",__LINE__);
        }

        size_t file_size = inf.filesize;            // 待下載文件大小
        size_t thread_size = MAX_THREADS;           // 線程數
        size_t percent = file_size / thread_size;   // 每個線程下載的文件大小
        cout << "filesize = " << file_size << "\t percent_blocks = " << percent << endl;

        struct message *blocks = (struct message *)malloc(sizeof(struct message) * thread_size);

        // 循環初始化每個線程的數據
        int i = 0;
        for (; i < thread_size; i++) {
            blocks[i].sc = sc;
            blocks[i].infd = fd;
            blocks[i].start = i * percent;
            blocks[i].end = blocks[i].start + percent;
        }
        blocks[i-1].end = file_size;

        pthread_t thid[thread_size];

        // 創建線程
        for (i = 0; i < thread_size; i++) {
            pthread_create(&thid[i], NULL, thread_send, &blocks[i]);
            sleep(1);
        }

        for (i = 0; i < thread_size; i++) {
            pthread_join(thid[i], NULL);
        }

        free(blocks);       // 釋放空間
        close(fd);          // 關閉文件描述符
        cout << "\nsend successful!" << endl;
    }
    close(sc);              // 關閉套接字
}

// 線程發送文件
void *thread_send(void *arg)
{
    struct message *block = (struct message *)arg;
    size_t count = block->start;
    int bytes_read;

    cout << "\nIn thread " << pthread_self() << "\nstart = " 
         << block->start << ", end = " << block->end << endl;

    // lseek到同樣的位置
    int ret = lseek(block->infd, block->start, SEEK_SET);

    // 循環發送文件
    while (count < block->end) {
        bytes_read = read(block->infd, inf.file, 1);
        if (send(block->sc, inf.file, bytes_read, 0) < 0) {
            cout << "send error !!!" << endl;
            exit(-1);
        }
        count += bytes_read;
        memset(&inf.file, 0, sizeof(inf.file));
    }
    cout << "Thread " << pthread_self() << " exit !!!" << endl;
    pthread_exit(NULL);
}

int main(int argc, char *argv[])  
{  
    int ss, sc;  
    int optval;
    int err;
    pthread_t thid;
    Download_data dd;
    struct sockaddr_in server_addr;     // 服務器地址結構 
    struct sockaddr_in client_addr;     // 客戶端地址結構

    signal(SIGINT, handler_sigint);     // 屏蔽Ctrl+C

    if ( (ss = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) {
        dd.my_err("socket", __LINE__);
    }  

    // 設置該套接字使之可以重新綁定端口
    optval = 1;
    if ( setsockopt(ss, SOL_SOCKET, SO_REUSEADDR, (void *)&optval, sizeof(int)) < 0 ) {
        dd.my_err("setsockopt", __LINE__);
    }

    // 設置服務器地址
    bzero(&server_addr, sizeof(server_addr));   // 清零 
    server_addr.sin_family = AF_INET;           // 協議族 
    server_addr.sin_port = htons(PORT);         // 服務器端口
    server_addr.sin_addr.s_addr = htons(INADDR_ANY);  

    // 綁定地址結構到套接字描述符
    if ( (err = bind(ss, (struct sockaddr *)&server_addr, sizeof(server_addr))) < 0 ) {  
        dd.my_err("bind", __LINE__);
    }  

    // 設置監聽  
    if ( (err = listen(ss, BACKLOG)) < 0) {
        dd.my_err("listen", __LINE__);
    }  

    socklen_t length = sizeof(client_addr); 

    // 服務器端一直運行用以持續爲客戶端提供服務  
    while (1) {  

        if ( (sc = accept(ss, (struct sockaddr *)&client_addr, &length)) < 0) {
            dd.my_err("accept", __LINE__);
        }  
        cout << "新的連接ip: " << inet_ntoa(client_addr.sin_addr) << "\nsocket_fd: " << sc << endl;

        pthread_create(&thid, NULL, process, (void *)&sc);
    }  
    close(ss);          // 關閉套接字

    return 0;  
}  
// 客戶端代碼
/*************************************************************************
    > File Name: client.cpp
    > Author: hp
    > Mail: [email protected]
    > Created Time: 2017年01月28日 星期六 20時30分25秒
 ************************************************************************/

#include <iostream>
#include <cstring>
#include <thread>
#include <unistd.h>
#include <fcntl.h>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT            8888        // 偵聽端口地址
#define MAXLEN          512         // 最大消息長度
const int MAX_THREADS = 2;          // 最大線程數

using namespace std;

struct message {
    int cs;                     // 套接字
    int flag;                   // 判斷下載文件是否存在
    int outfd;                  // 目標文件句柄
    size_t start;               // 文件寫入的起始位置
    size_t end;                 // 文件下載的結束位置
    int filesize;               // 下載文件大小
    char filelist[10000];       // 下載文件列表
    char filename[256];         // 下載文件名
    char file[1024];            // 文件內容
} inf;

class Download_data {
public:
    int socket_client(void );                   // 客戶端socket函數
    void process(struct message, int );         // 客戶端對服務器端的處理
    void my_err(const char *, int );            // 錯誤處理函數
};

void *thread_recv(void *arg);       // 線程接收文件

// 錯誤處理函數
void Download_data::my_err(const char *err_string, int line)
{
    fprintf(stderr, "line: %d ", line);
    perror(err_string); 
    exit(1);
}

// 客戶端socket函數
int Download_data::socket_client(void)
{
    int cs;
    Download_data dd;
    struct sockaddr_in server_addr;

    if ( (cs = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) {
        dd.my_err("socket", __LINE__);
    }

    // 設置服務器地址
    bzero(&server_addr, sizeof(server_addr));  
    server_addr.sin_family = AF_INET;                   // internet協議族  
    server_addr.sin_addr.s_addr = htons(INADDR_ANY);    // INADDR_ANY表示自動獲取本機地址  
    server_addr.sin_port = htons(PORT);

    // 服務器的IP地址來自程序的參數  
    if (inet_aton("127.0.0.1", &server_addr.sin_addr) == 0) {  
        dd.my_err("inet_aton", __LINE__);
    }  

    socklen_t server_addr_length = sizeof(server_addr);  

    if ( connect(cs, (struct sockaddr *)&server_addr, server_addr_length) < 0 ) {  
        dd.my_err("connect", __LINE__);
    } 
    return cs; 
}

/* 客戶端對服務器端的處理 */
void Download_data::process(struct message info, int cs)
{
    int fd;
    Download_data dd;
    // 打開待下載文件
    if ( (fd = open(info.filename, O_RDWR|O_CREAT|O_APPEND, 0644)) < 0 ) {
        dd.my_err("open", __LINE__);
    }

    size_t file_size = info.filesize;           // 待下載文件大小
    size_t thread_size = MAX_THREADS;           // 線程數
    size_t percent = file_size / thread_size;   // 每個線程下載文件大小
    cout << "filesize = " << file_size << "\t percent_blocks = " << percent << endl;

    struct message *blocks = (struct message *)malloc(sizeof(struct message) * thread_size);

    int i = 0;
    for (; i < thread_size; i++) {
        blocks[i].cs = cs;
        blocks[i].outfd = fd;
        blocks[i].start = i * percent;
        blocks[i].end = blocks[i].start + percent;
    }
    blocks[i-1].end = file_size;

    pthread_t thid[thread_size];

    // 創建線程
    for (i = 0; i < thread_size; i++) {
        pthread_create(&thid[i], NULL, thread_recv, &blocks[i]);
        sleep(1);
    }
    for (i = 0; i < thread_size; i++) {
        pthread_join(thid[i], NULL);
    }

    free(blocks);       // 釋放空間
    close(fd);          // 關閉文件描述符
}

// 線程接收文件
void *thread_recv(void *arg)
{
    struct message *block = (struct message *)arg;
    size_t count = block->start;
    int bytes_write, length = 0;

    cout << "\nIn thread " << pthread_self() << "\nstart = " 
         << block->start << ", end = " << block->end << endl;

    // lseek到同樣的位置
    lseek(block->outfd, block->start, SEEK_SET);

    // 循環接收文件內容並寫入
    while (count < block->end) {
        recv(block->cs, block->file, 1, 0);
        bytes_write = write(block->outfd, block->file, 1);
        count += bytes_write;
        memset(block->file, 0, sizeof(block->file));
        if (count == block->end) {
            break;
        }
    }

    cout << "Thread " << pthread_self() << " exit !!!" << endl;
    pthread_exit(NULL);
}

int main(int argc, char *argv[])  
{  
    int cs;
    clock_t start, end;
    double duration;
    Download_data dd;

    cs = dd.socket_client();

    // 清空結構體
    memset(&inf, 0, sizeof(struct message));
    // 接收服務器端發送過來的結構體
    if (recv(cs, &inf, sizeof(struct message), 0) < 0) {
        dd.my_err("recv", __LINE__);
    }
    string filelist = inf.filelist;     // 將接收到的下載文件列表賦給filelist

    while (1) {
        cout << "\n======================\n";
        cout << "**** 下載文件列表 ****\n";
        cout << "======================\n\n";
        cout << filelist << endl;

        memset(inf.filename, 0, sizeof(inf.filename));
        cout << "請輸入要下載的文件名: ";
        cin >> inf.filename;

        // 將用戶輸入的文件名發送至服務器
        if ( send(cs, inf.filename, sizeof(inf.filename), 0) < 0 ) {
            dd.my_err("send", __LINE__);
        }
        // 接收服務器端發送過來的結構體
        if (recv(cs, &inf, sizeof(struct message), 0) < 0) {
            dd.my_err("recv", __LINE__);
        }
        // 判斷用戶輸入的待下載文件是否存在
        // 不存在則等待2s後重新輸入
        if (inf.flag == 0) {
            cout << "\n待下載文件名輸入錯誤,請重新輸入!!!" << endl;
            sleep(2);
            continue;
        }

        start = clock();        // 開始計時
        dd.process(inf, cs);    // 處理下載文件
        end = clock();          // 結束計時

        duration = (double)(end - start) / CLOCKS_PER_SEC;     // 計算下載時間


        cout << "\n下載成功!" << endl;
        cout << "耗時: " << duration << " s." << endl;

    }

    close(cs);      // 關閉套接字

    return 0; 
}  
// 打開兩個終端,分別編譯服務器端代碼和客戶端代碼

g++ -o server server.cpp -g -lpthread
g++ -o client client.cpp -g -lpthread

// 打開兩個終端,分別運行服務器和客戶端

./server
./client
發佈了48 篇原創文章 · 獲贊 141 · 訪問量 28萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章