socket多线程下载文件

多线程下载的思路是建立多个线程,同时连接到服务器,分别从文件的不同位置开始下载,然后将各自接收到的数据合并到同一个文件中。

// 服务器端代码
/*************************************************************************
    > File Name: server.cpp
    > Author: hp
    > Mail: [email protected]
    > Created Time: 2017年01月28日 星期六 19时39分47秒
 ************************************************************************/

#include <iostream>
#include <cstring>
#include <thread>
#include <unistd.h>
#include <fcntl.h>
#include <dirent.h>
#include <error.h>
#include <sys/types.h>  
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h> 
#include <arpa/inet.h>
#include <signal.h>

#define PORT            8888        // 侦听端口地址
#define BACKLOG         20          // 侦听队列长度
#define MAXLEN          512         // 消息最大长度
const int MAX_THREADS = 2;          // 最大线程数

using namespace std;

struct message {
    int sc;                     // 套接字
    int flag;                   // 判断下载文件是否存在
    int infd;                   // 源文件句柄
    size_t start;               // 文件的写入起始位置
    size_t end;                 // 文件的写入终止位置
    int filesize;               // 下载文件大小
    char filelist[10000];       // 下载文件列表
    char filename[256];         // 下载文件名
    char file[1024];            // 文件内容
} inf;


class Download_data {
public: 
    void get_filelist(void);            // 获取下载文件列表
    int get_size(const char *);         // 获取文件大小
    void my_err(const char *, int );    // 错误处理函数
};

void handler_sigint(int );          // 信号处理函数
void *process(void *);              // 服务器对客户端的处理
void *thread_send(void *);          // 线程发送文件

// 错误处理函数
void Download_data::my_err(const char *err_string, int line)
{
    fprintf(stderr, "line: %d ", line);
    perror(err_string); 
    exit(1);
}

// 获取下载文件列表
void Download_data::get_filelist(void)
{
    DIR *dir;
    struct dirent *ptr;
    char path[MAXLEN];

    // 获取当前工作路径
    getcwd(path, MAXLEN);

    // 打开当前工作路径
    if ((dir = opendir(path)) == NULL) {
        perror("opendir");
    }

    // 循环将当前路径下的所有文件名连接起来
    while ((ptr = readdir(dir)) != NULL) {
        // 如果是'.'或'..'则跳过
        if (strcmp(ptr->d_name, ".") == 0 || strcmp(ptr->d_name, "..") == 0) {
            continue;
        }
        strcat(inf.filelist, ptr->d_name);
        strcat(inf.filelist, "\n");
        //sprintf(filelist, "%s%s", ptr->d_name, "\n");
    }
    // 关闭目录流
    closedir(dir);
}

// 获取文件大小
int Download_data::get_size(const char *file_name) 
{
    struct stat st;
    bzero(&st, sizeof(st));
    stat(file_name, &st);
    return st.st_size;
}

// 信号处理函数
void handler_sigint(int signo)
{
}

/* 服务器对客户端的处理 */
void *process(void *arg)
{
    int fd;
    int sc = *(int *)arg;
    Download_data dd;

    // 获取当前工作路径下的所有文件名并发送至客户端
    dd.get_filelist();
    if ( send(sc, &inf, sizeof(struct message), 0) < 0 ) {
        dd.my_err("send", __LINE__);
    }

    while (1) {
        // 清空结构体里的数据
        memset(&inf, 0, sizeof(struct message));

        // 接收来自客户端的待下载文件名
        if ( recv(sc, inf.filename, sizeof(inf.filename), 0) < 0 ) {
            dd.my_err("recv", __LINE__);
        }  

        // 如果待下载文件存在,则将标志位置为1,并获取该文件的大小
        // 否则将标志位置为0
        if (access(inf.filename, 0) == 0) {
            inf.flag = 1;
            inf.filesize = dd.get_size(inf.filename);
        } else {
            inf.flag = 0;
        }

        // 将结构体发至客户端
        if (send(sc, &inf, sizeof(struct message), 0) < 0) {
            dd.my_err("send", __LINE__);
        }

        // 如果待下载文件不存在,则继续等待客户端的再次输入
        if (inf.flag == 0) {
            continue;
        }

        // 打开待下载文件
        if ( (fd = open(inf.filename, O_RDONLY)) < 0 ) {
            dd.my_err("open",__LINE__);
        }

        size_t file_size = inf.filesize;            // 待下载文件大小
        size_t thread_size = MAX_THREADS;           // 线程数
        size_t percent = file_size / thread_size;   // 每个线程下载的文件大小
        cout << "filesize = " << file_size << "\t percent_blocks = " << percent << endl;

        struct message *blocks = (struct message *)malloc(sizeof(struct message) * thread_size);

        // 循环初始化每个线程的数据
        int i = 0;
        for (; i < thread_size; i++) {
            blocks[i].sc = sc;
            blocks[i].infd = fd;
            blocks[i].start = i * percent;
            blocks[i].end = blocks[i].start + percent;
        }
        blocks[i-1].end = file_size;

        pthread_t thid[thread_size];

        // 创建线程
        for (i = 0; i < thread_size; i++) {
            pthread_create(&thid[i], NULL, thread_send, &blocks[i]);
            sleep(1);
        }

        for (i = 0; i < thread_size; i++) {
            pthread_join(thid[i], NULL);
        }

        free(blocks);       // 释放空间
        close(fd);          // 关闭文件描述符
        cout << "\nsend successful!" << endl;
    }
    close(sc);              // 关闭套接字
}

// 线程发送文件
void *thread_send(void *arg)
{
    struct message *block = (struct message *)arg;
    size_t count = block->start;
    int bytes_read;

    cout << "\nIn thread " << pthread_self() << "\nstart = " 
         << block->start << ", end = " << block->end << endl;

    // lseek到同样的位置
    int ret = lseek(block->infd, block->start, SEEK_SET);

    // 循环发送文件
    while (count < block->end) {
        bytes_read = read(block->infd, inf.file, 1);
        if (send(block->sc, inf.file, bytes_read, 0) < 0) {
            cout << "send error !!!" << endl;
            exit(-1);
        }
        count += bytes_read;
        memset(&inf.file, 0, sizeof(inf.file));
    }
    cout << "Thread " << pthread_self() << " exit !!!" << endl;
    pthread_exit(NULL);
}

int main(int argc, char *argv[])  
{  
    int ss, sc;  
    int optval;
    int err;
    pthread_t thid;
    Download_data dd;
    struct sockaddr_in server_addr;     // 服务器地址结构 
    struct sockaddr_in client_addr;     // 客户端地址结构

    signal(SIGINT, handler_sigint);     // 屏蔽Ctrl+C

    if ( (ss = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) {
        dd.my_err("socket", __LINE__);
    }  

    // 设置该套接字使之可以重新绑定端口
    optval = 1;
    if ( setsockopt(ss, SOL_SOCKET, SO_REUSEADDR, (void *)&optval, sizeof(int)) < 0 ) {
        dd.my_err("setsockopt", __LINE__);
    }

    // 设置服务器地址
    bzero(&server_addr, sizeof(server_addr));   // 清零 
    server_addr.sin_family = AF_INET;           // 协议族 
    server_addr.sin_port = htons(PORT);         // 服务器端口
    server_addr.sin_addr.s_addr = htons(INADDR_ANY);  

    // 绑定地址结构到套接字描述符
    if ( (err = bind(ss, (struct sockaddr *)&server_addr, sizeof(server_addr))) < 0 ) {  
        dd.my_err("bind", __LINE__);
    }  

    // 设置监听  
    if ( (err = listen(ss, BACKLOG)) < 0) {
        dd.my_err("listen", __LINE__);
    }  

    socklen_t length = sizeof(client_addr); 

    // 服务器端一直运行用以持续为客户端提供服务  
    while (1) {  

        if ( (sc = accept(ss, (struct sockaddr *)&client_addr, &length)) < 0) {
            dd.my_err("accept", __LINE__);
        }  
        cout << "新的连接ip: " << inet_ntoa(client_addr.sin_addr) << "\nsocket_fd: " << sc << endl;

        pthread_create(&thid, NULL, process, (void *)&sc);
    }  
    close(ss);          // 关闭套接字

    return 0;  
}  
// 客户端代码
/*************************************************************************
    > File Name: client.cpp
    > Author: hp
    > Mail: [email protected]
    > Created Time: 2017年01月28日 星期六 20时30分25秒
 ************************************************************************/

#include <iostream>
#include <cstring>
#include <thread>
#include <unistd.h>
#include <fcntl.h>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT            8888        // 侦听端口地址
#define MAXLEN          512         // 最大消息长度
const int MAX_THREADS = 2;          // 最大线程数

using namespace std;

struct message {
    int cs;                     // 套接字
    int flag;                   // 判断下载文件是否存在
    int outfd;                  // 目标文件句柄
    size_t start;               // 文件写入的起始位置
    size_t end;                 // 文件下载的结束位置
    int filesize;               // 下载文件大小
    char filelist[10000];       // 下载文件列表
    char filename[256];         // 下载文件名
    char file[1024];            // 文件内容
} inf;

class Download_data {
public:
    int socket_client(void );                   // 客户端socket函数
    void process(struct message, int );         // 客户端对服务器端的处理
    void my_err(const char *, int );            // 错误处理函数
};

void *thread_recv(void *arg);       // 线程接收文件

// 错误处理函数
void Download_data::my_err(const char *err_string, int line)
{
    fprintf(stderr, "line: %d ", line);
    perror(err_string); 
    exit(1);
}

// 客户端socket函数
int Download_data::socket_client(void)
{
    int cs;
    Download_data dd;
    struct sockaddr_in server_addr;

    if ( (cs = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) {
        dd.my_err("socket", __LINE__);
    }

    // 设置服务器地址
    bzero(&server_addr, sizeof(server_addr));  
    server_addr.sin_family = AF_INET;                   // internet协议族  
    server_addr.sin_addr.s_addr = htons(INADDR_ANY);    // INADDR_ANY表示自动获取本机地址  
    server_addr.sin_port = htons(PORT);

    // 服务器的IP地址来自程序的参数  
    if (inet_aton("127.0.0.1", &server_addr.sin_addr) == 0) {  
        dd.my_err("inet_aton", __LINE__);
    }  

    socklen_t server_addr_length = sizeof(server_addr);  

    if ( connect(cs, (struct sockaddr *)&server_addr, server_addr_length) < 0 ) {  
        dd.my_err("connect", __LINE__);
    } 
    return cs; 
}

/* 客户端对服务器端的处理 */
void Download_data::process(struct message info, int cs)
{
    int fd;
    Download_data dd;
    // 打开待下载文件
    if ( (fd = open(info.filename, O_RDWR|O_CREAT|O_APPEND, 0644)) < 0 ) {
        dd.my_err("open", __LINE__);
    }

    size_t file_size = info.filesize;           // 待下载文件大小
    size_t thread_size = MAX_THREADS;           // 线程数
    size_t percent = file_size / thread_size;   // 每个线程下载文件大小
    cout << "filesize = " << file_size << "\t percent_blocks = " << percent << endl;

    struct message *blocks = (struct message *)malloc(sizeof(struct message) * thread_size);

    int i = 0;
    for (; i < thread_size; i++) {
        blocks[i].cs = cs;
        blocks[i].outfd = fd;
        blocks[i].start = i * percent;
        blocks[i].end = blocks[i].start + percent;
    }
    blocks[i-1].end = file_size;

    pthread_t thid[thread_size];

    // 创建线程
    for (i = 0; i < thread_size; i++) {
        pthread_create(&thid[i], NULL, thread_recv, &blocks[i]);
        sleep(1);
    }
    for (i = 0; i < thread_size; i++) {
        pthread_join(thid[i], NULL);
    }

    free(blocks);       // 释放空间
    close(fd);          // 关闭文件描述符
}

// 线程接收文件
void *thread_recv(void *arg)
{
    struct message *block = (struct message *)arg;
    size_t count = block->start;
    int bytes_write, length = 0;

    cout << "\nIn thread " << pthread_self() << "\nstart = " 
         << block->start << ", end = " << block->end << endl;

    // lseek到同样的位置
    lseek(block->outfd, block->start, SEEK_SET);

    // 循环接收文件内容并写入
    while (count < block->end) {
        recv(block->cs, block->file, 1, 0);
        bytes_write = write(block->outfd, block->file, 1);
        count += bytes_write;
        memset(block->file, 0, sizeof(block->file));
        if (count == block->end) {
            break;
        }
    }

    cout << "Thread " << pthread_self() << " exit !!!" << endl;
    pthread_exit(NULL);
}

int main(int argc, char *argv[])  
{  
    int cs;
    clock_t start, end;
    double duration;
    Download_data dd;

    cs = dd.socket_client();

    // 清空结构体
    memset(&inf, 0, sizeof(struct message));
    // 接收服务器端发送过来的结构体
    if (recv(cs, &inf, sizeof(struct message), 0) < 0) {
        dd.my_err("recv", __LINE__);
    }
    string filelist = inf.filelist;     // 将接收到的下载文件列表赋给filelist

    while (1) {
        cout << "\n======================\n";
        cout << "**** 下载文件列表 ****\n";
        cout << "======================\n\n";
        cout << filelist << endl;

        memset(inf.filename, 0, sizeof(inf.filename));
        cout << "请输入要下载的文件名: ";
        cin >> inf.filename;

        // 将用户输入的文件名发送至服务器
        if ( send(cs, inf.filename, sizeof(inf.filename), 0) < 0 ) {
            dd.my_err("send", __LINE__);
        }
        // 接收服务器端发送过来的结构体
        if (recv(cs, &inf, sizeof(struct message), 0) < 0) {
            dd.my_err("recv", __LINE__);
        }
        // 判断用户输入的待下载文件是否存在
        // 不存在则等待2s后重新输入
        if (inf.flag == 0) {
            cout << "\n待下载文件名输入错误,请重新输入!!!" << endl;
            sleep(2);
            continue;
        }

        start = clock();        // 开始计时
        dd.process(inf, cs);    // 处理下载文件
        end = clock();          // 结束计时

        duration = (double)(end - start) / CLOCKS_PER_SEC;     // 计算下载时间


        cout << "\n下载成功!" << endl;
        cout << "耗时: " << duration << " s." << endl;

    }

    close(cs);      // 关闭套接字

    return 0; 
}  
// 打开两个终端,分别编译服务器端代码和客户端代码

g++ -o server server.cpp -g -lpthread
g++ -o client client.cpp -g -lpthread

// 打开两个终端,分别运行服务器和客户端

./server
./client
发布了48 篇原创文章 · 获赞 141 · 访问量 28万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章