詞向量源碼解析:(3.4)GloVe源碼解析之shuffle

這部分代碼的功能是打亂共現矩陣中三元組的順序。cooccur生成的三元組是排好序的。我沒有嘗試過用排好序的訓練能得到什麼結果。一個很簡單的shuffle方法是把所有的都讀入內存,在內存中打亂。但是在我們的內存不足以裝下所有的三元組的時候怎麼辦?這裏採用了一個兩階段的方法做shuffle,先局部shuffle,得到很多臨時文件。字第二階段,再均勻的從每個臨時文件中讀取三元組,放入內存。shuffle以後再寫出的就是最終shuffle好的共現矩陣了。下面介紹幾個關鍵函數

共現矩陣的三元組的數據結構,和之前一樣

typedef struct cooccur_rec {
    int word1;
    int word2;
    real val;
} CREC;

array是三元組數組,size是數組長度,把內存中的array寫到文件中,不需要像之前cooccur那樣還彙總

/* Write contents of array to binary file */
int write_chunk(CREC *array, long size, FILE *fout) {
    long i = 0;
    for (i = 0; i < size; i++) fwrite(&array[i], sizeof(CREC), 1, fout);
    return 0;
}

對內存中的三元組進行打亂順序shuffle,邏輯很簡單,就是在交換順序。

/* Fisher-Yates shuffle */
void shuffle(CREC *array, long n) {
    long i, j;
    CREC tmp;
    for (i = n - 1; i > 0; i--) {
        j = rand_long(i + 1);
        tmp = array[j];
        array[j] = array[i];
        array[i] = tmp;
    }
}

剛纔說了shuffle分成兩個階段,第一個階段是shuffle_by_chunks,順序的讀入文件中的三元組,局部排序

/* Shuffle large input stream by splitting into chunks */
int shuffle_by_chunks() {
    long i = 0, l = 0;
    int fidcounter = 0;
    char filename[MAX_STRING_LENGTH];
    CREC *array;
    FILE *fin = stdin, *fid;
    array = malloc(sizeof(CREC) * array_size);//存儲三元組
    
    fprintf(stderr,"SHUFFLING COOCCURRENCES\n");
    if (verbose > 0) fprintf(stderr,"array size: %lld\n", array_size);
    sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
    fid = fopen(filename,"w");
    if (fid == NULL) {
        fprintf(stderr, "Unable to open file %s.\n",filename);
        return 1;
    }
    if (verbose > 1) fprintf(stderr, "Shuffling by chunks: processed 0 lines.");
    
    while (1) { //Continue until EOF//循環讀入文件中的三元組
        if (i >= array_size) {// If array is full, shuffle it and save to temporary file//內存滿了的話就寫出
            shuffle(array, i-2);//打亂
            l += i;
            if (verbose > 1) fprintf(stderr, "\033[22Gprocessed %ld lines.", l);
            write_chunk(array,i,fid);//寫出
            fclose(fid);
            fidcounter++;
            sprintf(filename,"%s_%04d.bin",file_head, fidcounter);//再向新的文件寫入
            fid = fopen(filename,"w");
            if (fid == NULL) {
                fprintf(stderr, "Unable to open file %s.\n",filename);
                return 1;
            }
            i = 0;
        }
        fread(&array[i], sizeof(CREC), 1, fin);
        if (feof(fin)) break;
        i++;
    }
    shuffle(array, i-2); //Last chunk may be smaller than array_size
    write_chunk(array,i,fid);
    l += i;
    if (verbose > 1) fprintf(stderr, "\033[22Gprocessed %ld lines.\n", l);
    if (verbose > 1) fprintf(stderr, "Wrote %d temporary file(s).\n", fidcounter + 1);
    fclose(fid);
    free(array);
    return shuffle_merge(fidcounter + 1); // Merge and shuffle together temporary files//進入第二階段
}

第二階段再對之前的臨時文件打亂一次,最後輸出最終打亂的三元組,作爲glove的輸入

int shuffle_merge(int num) {//一共有num個臨時文件
    long i, j, k, l = 0;
    int fidcounter = 0;
    CREC *array;
    char filename[MAX_STRING_LENGTH];
    FILE **fid, *fout = stdout;
    
    array = malloc(sizeof(CREC) * array_size);
    fid = malloc(sizeof(FILE) * num);
    for (fidcounter = 0; fidcounter < num; fidcounter++) { //num = number of temporary files to merge//打開所有的臨時文件
        sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
        fid[fidcounter] = fopen(filename, "rb");
        if (fid[fidcounter] == NULL) {
            fprintf(stderr, "Unable to open file %s.\n",filename);
            return 1;
        }
    }
    if (verbose > 0) fprintf(stderr, "Merging temp files: processed %ld lines.", l);
    
    while (1) { //Loop until EOF in all files
        i = 0;
        //Read at most array_size values into array, roughly array_size/num from each temp file//從每個臨時文件讀入array_size/num個三元組,這樣內存中會基本均勻的有所有臨時文件的三元組,打亂他們再寫出
        for (j = 0; j < num; j++) {
            if (feof(fid[j])) continue;
            for (k = 0; k < array_size / num; k++){//從每個文件讀入固定數目的三元組,保證內存不會滿
                fread(&array[i], sizeof(CREC), 1, fid[j]);
                if (feof(fid[j])) break;
                i++;
            }
        }
        if (i == 0) break;//如果讀不出來了(也就是都讀完了)就結束了
        l += i;
        shuffle(array, i-1); // Shuffles lines between temp files//打亂
        write_chunk(array,i,fout);//寫出
        if (verbose > 0) fprintf(stderr, "\033[31G%ld lines.", l);
    }
    fprintf(stderr, "\033[0GMerging temp files: processed %ld lines.", l);
    for (fidcounter = 0; fidcounter < num; fidcounter++) {
        fclose(fid[fidcounter]);
        sprintf(filename,"%s_%04d.bin",file_head, fidcounter);
        remove(filename);
    }
    fprintf(stderr, "\n\n");
    free(array);
    return 0;
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章