ncnn之五:參數和模型文件結構

net.param

7767517 # magic number
158 183 # [layer count] [blob count]
# [layer type] [layer name] [input count] [output count] [input blobs] [output blobs] [layer specific params]
Input            data                             0 1 data
Convolution      conv1                            1 1 data conv1 0=64 1=3 11=3 3=2 13=2 4=0 14=0 5=1 6=1728
ReLU             relu_conv1                       1 1 conv1 relu_conv1
Convolution      conv2                            1 1 relu_conv1 conv2 0=64 1=3 11=3 3=2 13=2 4=0 14=0 5=1 6=36864

load_param代碼:

// Parse a text-format param stream into this network's layer/blob tables.
//
// Layout read from fp:
//   [magic]                      must be 7767517
//   [layer count] [blob count]
//   then one record per layer:
//   [layer type] [layer name] [input count] [output count]
//   [input blob names] [output blob names] [layer specific params]
//
// Returns 0 once the stream reaches end of file. Returns -1 when the magic
// number is wrong, the counts are unreadable or negative, a record carries a
// negative input/output count, a layer type cannot be created, or the stream
// contains more layers/blobs than the header declared (indexing past the
// pre-sized tables would otherwise be undefined behaviour).
//
// A layer whose specific params fail to load is skipped, not treated as fatal;
// its slot in `layers` is then reused by the next record.
int Net::load_param(FILE* fp)
{
    int magic = 0;
    // a failed scan leaves magic at 0, which is rejected by the same test
    if (fscanf(fp, "%d", &magic) != 1 || magic != 7767517)
    {
        fprintf(stderr, "param is too old, please regenerate\n");
        return -1;
    }

    // parse
    int layer_count = 0;
    int blob_count = 0;
    if (fscanf(fp, "%d %d", &layer_count, &blob_count) != 2 || layer_count < 0 || blob_count < 0)
    {
        // negative values would be converted to huge sizes by resize()
        fprintf(stderr, "invalid layer_count or blob_count\n");
        return -1;
    }

    layers.resize(layer_count);
    blobs.resize(blob_count);

    ParamDict pd;

    int layer_index = 0;
    int blob_index = 0;
    while (!feof(fp))
    {
        int nscan = 0;

        char layer_type[257];
        char layer_name[257];
        int bottom_count = 0;
        int top_count = 0;
        nscan = fscanf(fp, "%256s %256s %d %d", layer_type, layer_name, &bottom_count, &top_count);
        if (nscan != 4)
        {
            // incomplete record (typically the trailing whitespace before EOF)
            continue;
        }

        if (bottom_count < 0 || top_count < 0)
        {
            fprintf(stderr, "layer %s has invalid input or output count\n", layer_name);
            clear();
            return -1;
        }

        // checked before the layer is created so nothing is allocated on this path
        if (layer_index >= layer_count)
        {
            fprintf(stderr, "param has more layers than the declared layer count %d\n", layer_count);
            clear();
            return -1;
        }

        Layer* layer = create_layer(layer_type);
        if (!layer)
        {
            layer = create_custom_layer(layer_type);
        }
        if (!layer)
        {
            fprintf(stderr, "layer %s not exists or registered\n", layer_type);
            clear();
            return -1;
        }

        // NOTE(review): from here on, `layer` is not yet stored in `layers`. The
        // early exits and `continue` paths below drop the pointer without
        // releasing it; the correct way to release a created layer is not visible
        // in this excerpt -- confirm and free it on those paths.

        layer->type = std::string(layer_type);
        layer->name = std::string(layer_name);
//         fprintf(stderr, "new layer %d %s\n", layer_index, layer_name);

        layer->bottoms.resize(bottom_count);
        for (int i=0; i<bottom_count; i++)
        {
            char bottom_name[257];
            nscan = fscanf(fp, "%256s", bottom_name);
            if (nscan != 1)
            {
                continue;
            }

            // -1 is treated as "name not seen yet": claim the next free blob slot
            int bottom_blob_index = find_blob_index_by_name(bottom_name);
            if (bottom_blob_index == -1)
            {
                if (blob_index >= blob_count)
                {
                    fprintf(stderr, "param has more blobs than the declared blob count %d\n", blob_count);
                    clear();
                    return -1;
                }

                Blob& blob = blobs[blob_index];
                bottom_blob_index = blob_index;
                blob.name = std::string(bottom_name);
//                 fprintf(stderr, "new blob %s\n", bottom_name);
                blob_index++;
            }

            Blob& blob = blobs[bottom_blob_index];
            blob.consumers.push_back(layer_index);
            layer->bottoms[i] = bottom_blob_index;
        }

        layer->tops.resize(top_count);
        for (int i=0; i<top_count; i++)
        {
            char blob_name[257];
            nscan = fscanf(fp, "%256s", blob_name);
            if (nscan != 1)
            {
                continue;
            }

            // every output takes a fresh blob slot; bound-check before touching it
            if (blob_index >= blob_count)
            {
                fprintf(stderr, "param has more blobs than the declared blob count %d\n", blob_count);
                clear();
                return -1;
            }

            Blob& blob = blobs[blob_index];
            blob.name = std::string(blob_name);
//             fprintf(stderr, "new blob %s\n", blob_name);
            blob.producer = layer_index;
            layer->tops[i] = blob_index;
            blob_index++;
        }

        // layer specific params
        int pdlr = pd.load_param(fp);
        if (pdlr != 0)
        {
            fprintf(stderr, "ParamDict load_param failed\n");
            continue;
        }

        int lr = layer->load_param(pd);
        if (lr != 0)
        {
            fprintf(stderr, "layer load_param failed\n");
            continue;
        }

        layers[layer_index] = layer;
        layer_index++;
    }
    return 0;
}

// Open the param file at `protopath` in binary mode and hand the stream to
// load_param(FILE*). Returns that call's result, or -1 (after logging to
// stderr) when the file cannot be opened. The stream is closed before returning.
int Net::load_param(const char* protopath)
{
    FILE* param_file = fopen(protopath, "rb");
    if (param_file)
    {
        const int status = load_param(param_file);
        fclose(param_file);
        return status;
    }

    fprintf(stderr, "fopen %s failed\n", protopath);
    return -1;
}

參數字典 ParamDict

index:0~19 對應整型或浮點型數據(單個數值),
查閱 operation param weight table
在這裏插入圖片描述
index:-23300 減去 0~19 對應整型或浮點型數組,屬於特殊參數(可能沒有)。參數有兩種寫法:一種是 k=v 的類型;另一種是 k=len,v1,v2,v3….(數組類型)。這些參數在ncnn中存放到 ParamDict 結構中。根據讀取的index判斷,如果小於等於-23300,表示爲數組(實際的參數 id 爲 -index - 23300),那麼等號右邊第一個參數就是數組長度,後面順序就是數組內容,[array size],int,int,…,int或[array size],float,float,…,float,例如:

0=1 1=2.5 -23303=2,2.0,3.0

index爲-23303,表明當前參數爲數組,等號右邊第一個參數爲2,表明數組長度爲2,後面2.0,3.0就是數組的內容

int ParamDict::load_param(const DataReader& dr)
{
    clear();

//     0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0

    // parse each key=value pair
    int id = 0;
    while (dr.scan("%d=", &id) == 1)
    {
        bool is_array = id <= -23300;
        if (is_array)
        {
            id = -id - 23300;
        }

        if (is_array)
        {
            int len = 0;
            int nscan = dr.scan("%d", &len);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict read array length failed\n");
                return -1;
            }

            params[id].v.create(len);

            for (int j = 0; j < len; j++)
            {
                char vstr[16];
                nscan = dr.scan(",%15[^,\n ]", vstr);
                if (nscan != 1)
                {
                    fprintf(stderr, "ParamDict read array element failed\n");
                    return -1;
                }

                bool is_float = vstr_is_float(vstr);

                if (is_float)
                {
                    float* ptr = params[id].v;
                    nscan = sscanf(vstr, "%f", &ptr[j]);
                }
                else
                {
                    int* ptr = params[id].v;
                    nscan = sscanf(vstr, "%d", &ptr[j]);
                }
                if (nscan != 1)
                {
                    fprintf(stderr, "ParamDict parse array element failed\n");
                    return -1;
                }

                params[id].type = is_float ? 6 : 5;
            }
        }
        else
        {
            char vstr[16];
            int nscan = dr.scan("%15s", vstr);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict read value failed\n");
                return -1;
            }

            bool is_float = vstr_is_float(vstr);

            if (is_float)
                nscan = sscanf(vstr, "%f", &params[id].f);
            else
                nscan = sscanf(vstr, "%d", &params[id].i);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict parse value failed\n");
                return -1;
            }

            params[id].type = is_float ? 3 : 2;
        }
    }

    return 0;
}

其中 vstr_is_float函數,原理很簡單,就是判斷數字對應字符串中是否存在小數點'.'或字母'e'(不區分大小寫),對應小數的兩種寫法,一種正常的小數點表示法,一種是科學計數法。

// Decide whether a numeric token should be parsed as a float.
//
// Scans at most 16 characters, stopping at the terminating '\0', and returns
// true as soon as it sees a decimal point or an exponent marker ('e' or 'E').
// Returns false otherwise, including for an empty token.
//
// The exponent marker is compared directly and not through tolower(): passing a
// plain char with a negative value to tolower() is undefined behaviour, and its
// result depends on the current locale.
static bool vstr_is_float(const char vstr[16])
{
    // look ahead for determine isfloat
    for (int j=0; j<16; j++)
    {
        const char c = vstr[j];
        if (c == '\0')
            break;

        if (c == '.' || c == 'e' || c == 'E')
            return true;
    }

    return false;
}

net.bin

  +---------+---------+---------+---------+---------+---------+
  | weight1 | weight2 | weight3 | weight4 | ....... | weightN |
  +---------+---------+---------+---------+---------+---------+
  ^         ^         ^         ^
  0x0      0x80      0x140     0x1C0

其中 weight buffer

[flag] (optional)
[raw data]
[padding] (optional)

flag : unsigned int, little-endian, indicating the weight storage type, 0 => float32, 0x01306B47 => float16, otherwise => quantized int8, may be omitted if the layer implementation forced the storage type explicitly
raw data : raw weight data, little-endian, float32 data or float16 data or quantized table and indexes depending on the storage type flag
padding : padding space for 32bit alignment, may be omitted if already aligned

load_model代碼:

// Load the weights of every layer, in layer order, from the model stream `fp`.
//
// Requires the layer table to be populated already. Returns 0 when every layer
// loads, -1 when the table is empty, when a layer slot is null, or when a
// layer's load_model() reports failure. Loading stops at the first failure.
int Net::load_model(FILE* fp)
{
    if (layers.empty())
    {
        fprintf(stderr, "network graph not ready\n");
        return -1;
    }

    // load file
    int ret = 0;
    ModelBinFromStdio mb(fp);
    for (size_t i=0; i<layers.size(); i++)
    {
        Layer* layer = layers[i];

        // load_param(FILE*) pre-sizes `layers` and only assigns a slot once a
        // layer has loaded its params, so a skipped record can leave a slot
        // unassigned (presumably null); never dereference it
        if (!layer)
        {
            fprintf(stderr, "load_model error at layer %d, parameter file has inconsistent content.\n", (int)i);
            ret = -1;
            break;
        }

        int lret = layer->load_model(mb);
        if (lret != 0)
        {
            fprintf(stderr, "layer load_model %d failed\n", (int)i);
            ret = -1;
            break;
        }
    }
    return ret;
}

// Open the model file at `modelpath` in binary mode and hand the stream to
// load_model(FILE*). Returns that call's result, or -1 (after logging to
// stderr) when the file cannot be opened. The stream is closed before returning.
int Net::load_model(const char* modelpath)
{
    FILE* model_file = fopen(modelpath, "rb");
    if (model_file)
    {
        const int status = load_model(model_file);
        fclose(model_file);
        return status;
    }

    fprintf(stderr, "fopen %s failed\n", modelpath);
    return -1;
}

參考:
1 param-and-model-file-structure
2 operation param weight table

發佈了258 篇原創文章 · 獲贊 339 · 訪問量 64萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章