linux socket實現內核態和用戶態通信

最近有一個需求:需要把一些 Linux 用戶態命令的執行和結果檢查做成自動化。
比如在用戶態執行 lspci 命令,判斷列出的設備中是否有某個型號的 PCI 卡,這就需要 Linux 內核態和用戶態交互。實現方法是內核態編程:內核驅動(內核模塊)作為客戶端,通過 socket 把命令請求發給用戶態的 server;server 收到請求後執行該命令,再把輸出返回給內核驅動,由驅動判斷結果是否符合預期。

內核態 socket 編程的流程和用戶態 socket 編程一樣(創建、連接、收發),只是接口不同。內核提供了一組內核態的 socket API,用戶態的 socket API 基本上在內核中都有對應的函數。

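在 net/socket.c 中可以看到這些函數以 EXPORT_SYMBOL 導出,例如 sock_create_kern、sock_release、kernel_sendmsg、kernel_recvmsg、kernel_setsockopt 等(注意:kernel_setsockopt 已在 Linux 5.8 中移除)。
主要實現:
server 端(用戶態)代碼: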

#include <arpa/inet.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

#include "public.h"


/* Listening socket: created in main(), then bound/listened by server_init(). */
int server_listen_fd;
/* Socket of the current client connection; assigned by server_accept(). */
int server_accept_fd;
/* TCP port the server binds to; the kernel client connects to the same value (port_id = 2002). */
int port= 2002;

/*
 * Return the current wall-clock time in microseconds since the Unix epoch.
 *
 * The result is a long long because tv_sec * 1000000 is around 1.7e15 today
 * and does not fit in an int. With an int return the value was silently
 * truncated and wrapped roughly every 71 minutes (2^32 us). That can make an
 * elapsed-time comparison such as the one in server_recv() misfire.
 */
long long get_current_time(void)
{
    struct timeval now;

    gettimeofday(&now, NULL);
    return (long long)now.tv_sec * 1000000LL + (long long)now.tv_usec;
}


/*
 * Run `cmd` through the shell with popen() and append everything it writes
 * to stdout onto `result`.
 *
 * Returns 0 on success, or 1 if the pipe could not be opened.
 *
 * NOTE(review): `result` must already be NUL-terminated and large enough for
 * its current contents plus the whole command output. Nothing here bounds
 * the append.
 * NOTE(review): in this program `cmd` comes straight off the network (see
 * main) and is handed to /bin/sh unfiltered. Anyone who can reach the port
 * can run arbitrary commands.
 */
int excute_cmd(char* cmd, char* result) {
    char buffer[1024];
    char *end;
    FILE* pipe = popen(cmd, "r");

    if (!pipe)
        return 1;

    /*
     * Loop on fgets() rather than !feof(). On a read error feof() never
     * becomes true, so the old loop could spin forever. Appending at a
     * tracked end pointer also avoids re-scanning `result` with strcat()
     * for every chunk.
     */
    end = result + strlen(result);
    while (fgets(buffer, sizeof buffer, pipe) != NULL) {
        size_t n = strlen(buffer);

        memcpy(end, buffer, n + 1);   /* copy including the terminator */
        end += n;
    }

    pclose(pipe);
    return 0;
}


/*
 * Receive `len_recv` bytes from `fd` into request->reqbuf, spending at most
 * `timeout` milliseconds in total.
 *
 * Returns:
 *   - the number of bytes received once len_recv bytes have arrived;
 *   - 0 if the timeout expires first;
 *   - -1 if the peer closes the connection or recv() fails with a real error.
 *
 * The old version treated EAGAIN (no data *yet*) as a fatal error. Combined
 * with MSG_DONTWAIT, it gave up immediately and the timeout never applied.
 * Here EAGAIN/EINTR put the caller to sleep in poll() until data arrives or
 * the remaining time runs out.
 */
int server_recv(int fd, cmd_request *request, int len_recv, int timeout)
{
    int recved = 0;
    unsigned long start = get_current_time();               /* microseconds */
    unsigned long limit_us = (unsigned long)timeout * 1000UL;

    printf("server recv start %d time:%d!\n", len_recv, timeout);
    while (recved < len_recv) {
        /* Unsigned subtraction, as before, so a wrapped clock value still
           yields a sane difference. */
        unsigned long elapsed_us = get_current_time() - start;
        ssize_t len;

        if (elapsed_us > limit_us) {
            printf("server recv timeout!\n");
            return 0;
        }

        len = recv(fd, request->reqbuf + recved, len_recv - recved, MSG_DONTWAIT);
        if (len > 0) {
            recved += (int)len;
            /* Precision-limited so an unterminated buffer is never over-read. */
            printf("recv buf is %.*s:\n", recved, request->reqbuf);
            printf("server has recv %zd\n", len);
            continue;
        }
        if (len == 0) {
            /* Orderly shutdown by the peer before the full request arrived. */
            printf("server recv error!\n");
            return -1;
        }
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
            struct pollfd pfd = { .fd = fd, .events = POLLIN };
            int wait_ms = (int)((limit_us - elapsed_us) / 1000UL) + 1;

            poll(&pfd, 1, wait_ms);
            continue;
        }
        printf("server recv error!\n");
        return -1;
    }

    printf("recved %d bytes!\n", recved);
    return recved;
}

/*
 * Wait for the next client on server_listen_fd and store the connected
 * socket in server_accept_fd. Exits the process on an unrecoverable accept()
 * error.
 *
 * server_init() makes server_listen_fd non-blocking. Instead of spinning on
 * EAGAIN, this sleeps in poll() until a connection is pending.
 *
 * Returns the accepted descriptor (also stored in server_accept_fd).
 */
int server_accept()
{
    int opt = 1;
    struct sockaddr_in server_accept_addr;
    socklen_t size;

    for (;;) {
        memset(&server_accept_addr, 0, sizeof(server_accept_addr));
        size = sizeof(server_accept_addr);
        server_accept_fd = accept(server_listen_fd,
                                  (struct sockaddr*)&server_accept_addr, &size);
        if (server_accept_fd >= 0)
            break;

        /* Transient conditions: nothing pending yet, a signal, or a client
           that reset the connection while it was still queued. */
        if (errno == EAGAIN || errno == EWOULDBLOCK ||
            errno == EINTR || errno == ECONNABORTED) {
            struct pollfd pfd = { .fd = server_listen_fd, .events = POLLIN };

            poll(&pfd, 1, -1);
            continue;
        }
        perror("error:socket accept1 exited!\n");
        exit(1);
    }

    /*
     * Socket options must be applied to the NEW connection. The old code
     * configured server_accept_fd before accept() assigned it: that was the
     * previous client's fd, or fd 0 (stdin) on the first call.
     * The connection is left in blocking mode on purpose. server_recv() asks
     * for non-blocking reads per call via MSG_DONTWAIT, and a blocking
     * socket lets main()'s send() deliver the whole response.
     */
    if (setsockopt(server_accept_fd, IPPROTO_TCP, TCP_NODELAY,
                   (char *)&opt, sizeof(int)) == -1)
        perror("setsockopt TCP_NODELAY");

    return server_accept_fd;
}

/*
 * Bind server_listen_fd to `ip`:`port`, make it non-blocking and start
 * listening (backlog 10). Exits the process if the address is invalid or if
 * bind()/listen() fail.
 *
 * Returns 0.
 */
int server_init(char *ip)
{
    int opt = 1;
    int flags;
    struct sockaddr_in server_listen_addr;

    memset(&server_listen_addr, 0, sizeof(server_listen_addr));
    server_listen_addr.sin_family = AF_INET;
    server_listen_addr.sin_port = htons(port);

    /* inet_pton() reports malformed input. inet_addr() returns INADDR_NONE,
       which is indistinguishable from a real 255.255.255.255. */
    if (inet_pton(AF_INET, ip, &server_listen_addr.sin_addr) != 1) {
        fprintf(stderr, "invalid IPv4 address: %s\n", ip);
        exit(1);
    }

    /* Allow a quick restart while old connections sit in TIME_WAIT. */
    if (setsockopt(server_listen_fd, SOL_SOCKET, SO_REUSEADDR,
                   (char *)&opt, sizeof(int)) == -1)
        perror("setsockopt SO_REUSEADDR");

    if (bind(server_listen_fd, (struct sockaddr*)&server_listen_addr,
             sizeof(server_listen_addr)) == -1) {
        perror("can't to bind");
        exit(1);
    }

    flags = fcntl(server_listen_fd, F_GETFL, 0);
    if (flags == -1 ||
        fcntl(server_listen_fd, F_SETFL, flags | O_NONBLOCK) == -1)
        perror("can't set O_NONBLOCK");   /* accept() just blocks instead */

    if (listen(server_listen_fd, 10) == -1) {
        perror("can't to listen");
        exit(1);
    }

    return 0;
}
/*
 * Command server entry point.
 *
 * Usage: <program> <listen-ip>
 *
 * Accepts one TCP client at a time. For each client it:
 *   1. receives a fixed-size cmd_request;
 *   2. runs reqbuf as a shell command;
 *   3. sends the fixed-size rspbuf back;
 *   4. closes the connection.
 */
int main(int argc,char *argv[])
{
    int ret;
    int timeout = 1000;   /* milliseconds allowed for receiving one request */
    cmd_request  request;
    cmd_response response;

    if (argc < 2 || !argv[1]) {
        printf("need ip!\n");
        return -1;
    }

    server_listen_fd = socket(AF_INET,SOCK_STREAM,0);
    if (-1 == server_listen_fd) {
        perror("fail to create socket!");
        exit(1);
    }

    server_init(argv[1]);

    while (1) {
        printf("server socket begin accept:\n");
        server_accept();

        /*
         * Reset both buffers for every client. excute_cmd() appends to
         * rspbuf, so clearing it only once (as before) let output accumulate
         * across clients until it overflowed.
         */
        memset(&request, 0, sizeof(request));
        memset(&response, 0, sizeof(response));

        ret = server_recv(server_accept_fd, &request, sizeof(cmd_request), timeout);
        if (ret <= 0) {
            /* Don't run whatever partial or empty command is in the buffer. */
            printf("server recv error!\n");
            close(server_accept_fd);
            continue;
        }
        /* Guarantee a terminator before treating reqbuf as a C string.
           NOTE(review): assumes reqbuf is an array member, as the
           sizeof(response.rspbuf) usage below suggests for rspbuf. */
        request.reqbuf[sizeof(request.reqbuf) - 1] = '\0';

        printf("DATA:[%s]\n", request.reqbuf);
        excute_cmd(request.reqbuf,response.rspbuf);
        printf("the result is %s",response.rspbuf);

        /* MSG_NOSIGNAL: a client that already disconnected must not kill
           the server with SIGPIPE. */
        ret = send(server_accept_fd, response.rspbuf, sizeof(response.rspbuf), MSG_NOSIGNAL);
        if (ret <= 0){
            printf("send %d failed!\n",ret);
        }

        /* One request per connection; without this every client leaked an fd. */
        close(server_accept_fd);
    }

    return 0;
}

客戶端(內核模塊)實現:

#include <linux/module.h>
#include <linux/init.h>

#include <linux/socket.h>
#include <net/sock.h>
#include <linux/in.h>
#include <linux/tcp.h>

#include <linux/in.h>
#include <linux/inet.h>
#include <linux/time.h>

#include "public.h"
/* TCP port of the user-space command server (the server uses port = 2002 too). */
int port_id = 2002;
/* IPv4 address of the command server. Declared as a charp module parameter,
   so it can be overridden when the module is loaded. */
char * dst_ip = "192.168.0.103";
module_param(dst_ip, charp, S_IRUSR);

/* NOTE(review): these two pointers are not referenced by any code visible in
   this file; presumably left over from the disabled tc_run sketch. */
cmd_request *prequest;
cmd_response *presponse;


/* Forward declaration; client_init() is defined further down. */
int client_init(void);
int plugin_get_current_time()
{
    struct timeval stime;
    do_gettimeofday(&stime);
    return (stime.tv_sec * 1000000 + stime.tv_usec) / 1000;
}

int client_send(struct socket *sock, unsigned char *pbufsend, int len_send)
{
    int len;
    int sended = 0;
    unsigned long last_time = plugin_get_current_time();
    struct kvec vec;
    struct msghdr msg;
    unsigned short timeout = 1000;

    while (1) {
        if (plugin_get_current_time() - last_time > timeout){
            printk("kernel send msg timeout!\n");
            break;
        }
        vec.iov_base = pbufsend + sended;
        vec.iov_len = len_send - sended;
        msg.msg_flags = 0;
        len = kernel_sendmsg(sock, &msg, &vec, 1, len_send - sended);
        if (len < 0){
            if (len == -EWOULDBLOCK){
                printk("kernel send msg would block!\n");
                continue;
            } else {
                printk("kernel send msg failed with < 0!\n");
                return -EINVAL;
            }
        } else if (len == 0){
            printk("kernel send msg failed with = 0!\n");
            return -EINVAL;
        }
        sended += len;
        if (sended >= len_send) {
            printk("kernel send %d bytes!\n", sended);
            return sended;
        }
    }

}

int client_init(void)
{   
    struct socket *sock;
    struct sockaddr_in s_addr;
    int ret = 0;

    sock = (struct socket*)kmalloc(sizeof(struct socket), GFP_KERNEL);
    memset(&s_addr, 0, sizeof(s_addr));
    s_addr.sin_family = AF_INET;
    s_addr.sin_port = htons(port_id);
    s_addr.sin_addr.s_addr = in_aton(dst_ip);

    /*create socket*/
    ret = sock_create_kern(AF_INET, SOCK_STREAM, 0, &sock);
    if (ret) {
        printk("socket create failed\n");
        return ret;
    }

    printk("create socket ok!\n");

    /*connect server*/
    ret = sock->ops->connect(sock, (struct sockaddr*)&s_addr, sizeof(struct sockaddr_in), 0);
    if (ret) {
        printk("socket connect server failed!\n");
        return ret;
    }
    printk("connect server ok!\n"); 

    // set opt
    int opt = 1;
    int flags;
    kernel_setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&opt, sizeof(int));
    flags = kernel_sock_ioctl(sock, F_GETFL, 0);
    kernel_sock_ioctl(sock, F_SETFL, flags | O_NONBLOCK);

    /*kmalloc sendbuf*/
    char *sendbuf = NULL;
    sendbuf = kmalloc(1024, GFP_KERNEL);
    memset(sendbuf, 0, 1024);
    strcpy(sendbuf, "lspci");
    printk("the request is %s, size is %d\n", sendbuf, sizeof(sendbuf));
    ret = client_send(sock, sendbuf, 1024);

    if (ret <= 0) {
        printk("client send failed !\n");
        return -EINVAL;
    }

    recvbuf = kmalloc(1024, GFP_KERNEL);
    memset(recvbuf, 0, 1024);
    memset(&msg, 0, sizeof(msg));
    memset(&vec, 0, sizeof(vec));
    vec.iov_base = recvbuf; 
    vec.iov_len=1024;
    int count = 0;
    while (count < 1000) {
        ret = kernel_recvmsg(sock, &msg, &vec, 1, 1024, 0);
        if(ret < 0){
            printk("client:kernel_sendmsg error!\n");
            return ret;
        } else if (ret > 0) {
            printk("recv message %s\n",recvbuf);
            break;
        }
        count ++;
    }

    if (count >= 1000)
        printk("kernel recv msg timeout!\n");

    //判斷結果是否符合預期
    char *expect = "PCI";
    if (strstr(recvbuf, expect) != NULL) {
        printk("lspci test pass!\n");

    }

    return ret;

}



/*
 * Disabled sketch of a test-case entry point. It is excluded from the build
 * by #if 0 and, as written, would not compile on its own:
 *   - `socket` is used without `struct` (unless a typedef exists elsewhere);
 *   - psock is read uninitialized;
 *   - tc_run has no return statement.
 * NOTE(review): skip_tc and client_recv_and_send are not defined in this
 * file; presumably in public.h or elsewhere — verify before re-enabling.
 */
#if 0
int tc_run(skip_tc *pskip_tc)
{
    socket *psock;
    pskip_tc->desc = psock;
    client_init();
    client_recv_and_send(pskip_tc,pskip_tc->request,sizeof(cmd_request),pskip_tc->response,sizeof(cmd_response));
}

#endif

/*
 * Module load hook: runs the single lspci round trip via client_init().
 *
 * The module still loads when the round trip fails (returns 0 either way),
 * as before. The failure is now logged instead of being silently discarded.
 */
static int __init plugin_init(void)
{
    int ret;

    printk("hello, plugin!\n");
    ret = client_init();
    if (ret < 0)
        printk("client_init failed: %d\n", ret);
    return 0;
}

/*
 * Module unload hook.
 * It only writes a farewell line to the kernel log; nothing is released here.
 */
static void __exit plugin_exit(void)
{
    /* unload notice */
    printk("goodbye,plugin!\n");
}
/* Register plugin_init() to run at insmod time and plugin_exit() at rmmod time. */
module_init(plugin_init);
module_exit(plugin_exit);
/* Module license declaration. */
MODULE_LICENSE("GPL");

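使用示例:
服務端:先在用戶態啟動服務端,命令行參數為要監聽的 IP,例如 `<服務端程序> 192.168.0.103`。
(原文此處為服務端運行截圖,圖片已缺失)
客戶端:再加載內核模塊,可用模塊參數 dst_ip 指定服務端 IP,例如 `insmod <模塊名>.ko dst_ip=192.168.0.103`。
(原文此處為客戶端運行截圖,圖片已缺失)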

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章