計算圖(graph)的遍歷

很久沒有寫博客了,忙並不是藉口,懶纔是理由。
一直想重構CupCnn,寫成一個通用的計算圖,能隨意搭建各種神經網絡(CupCnn只能搭建有一個鏈路的有序的神經網絡),然後把名字也改了,叫CupDnn好了。所以,今天先寫着試下遍歷一個計算圖吧。
先隨便構造一個簡單的計算圖,如下:
(圖1:input 的輸出連到 first、second、third;first 連到 forth;second、third、forth 三者都連到 fifth。)
關於這個圖的說明:
假設每一個Unit都執行一些計算,把計算結果推送給與它相連接的Unit,下一個Unit對輸入再做計算。也就是說,如果有一個Unit,有三個輸入,那麼爲了完成計算,它必須等待三個輸入都將結果輸出給它,它纔可以執行它的計算。比如說,fifth需要second,third,forth三個圖元的輸出結果,它必須等這三個圖元都完成計算,並將結果遞送給它,它纔可以進行計算。這裏的計算就是打印圖元的名字。

遍歷分爲深度優先和廣度優先,深度優先用遞歸實現,廣度優先用任務隊列實現。
代碼如下:

// A node ("unit") of the computation graph.
//   name       - label printed when the unit runs or is destroyed
//   visit      - set to true once the unit has run during a scan
//   inputSize  - number of producers feeding this unit; the destroy
//                routines count it down, so they consume it
//   outputSize - number of consumers; 0 marks a sink unit
//   outputs    - consumers of this unit (heap-allocated list, deleted by the
//                destroy routines; nullptr for the sinks built in this file)
//   inputs     - producers that have delivered a result so far; filled in by
//                the scan routines (nullptr for the graph's input unit)
// Every member has a default so a freshly `new`-ed Unit is never read
// uninitialized.
struct graph_unit {
    string name;
    bool visit = false;
    int inputSize = 0;
    int outputSize = 0;
    vector<graph_unit *> *outputs = nullptr;
    vector<graph_unit *> *inputs = nullptr;
};
using Unit = graph_unit;
using PUnit = graph_unit *;

// Builds the sample graph and returns its input unit:
//   input  -> first, second, third
//   first  -> forth
//   second, third, forth -> fifth
// All units are allocated with `new`; the caller owns the returned graph.
PUnit createGraph()
{
    // Allocates one unit. An adjacency list is created only for a side that
    // has at least one edge; the other side is left as nullptr.
    auto makeUnit = [](const string &label, int fanIn, int fanOut) {
        PUnit unit = new Unit;
        unit->name = label;
        unit->inputSize = fanIn;
        unit->outputSize = fanOut;
        unit->visit = false;
        unit->inputs = fanIn ? new vector<PUnit> : nullptr;
        unit->outputs = fanOut ? new vector<PUnit> : nullptr;
        return unit;
    };

    PUnit input  = makeUnit("intput", 0, 3);
    PUnit first  = makeUnit("first", 1, 1);
    PUnit second = makeUnit("second", 1, 1);
    PUnit third  = makeUnit("third", 1, 1);
    PUnit forth  = makeUnit("forth", 1, 1);
    PUnit fifth  = makeUnit("fifth", 3, 0);

    // Fan-out of the input unit.
    input->outputs->push_back(first);
    input->outputs->push_back(second);
    input->outputs->push_back(third);

    // first feeds forth.
    first->outputs->push_back(forth);

    // fifth joins three producers.
    second->outputs->push_back(fifth);
    third->outputs->push_back(fifth);
    forth->outputs->push_back(fifth);

    return input;
}

// Breadth-first traversal that frees every unit of the graph.
// A unit is freed only after all of its producers have been processed:
// each producer decrements the consumer's inputSize, and the consumer is
// queued when that counter reaches zero. NOTE: inputSize is consumed
// (mutated) by this routine.
void broadScanDestroyGraph(PUnit input)
{
    queue<PUnit> pending;
    pending.push(input);
    while (!pending.empty()) {
        PUnit unit = pending.front();
        pending.pop();
        if (unit == nullptr) {
            continue;
        }
        if (unit->outputs != nullptr) {
            for (PUnit consumer : *unit->outputs) {
                if (--consumer->inputSize == 0) {
                    pending.push(consumer);
                }
            }
        }
        cout << "destory: " << unit->name << endl;
        // delete on a null pointer is a no-op, so no guards are needed.
        delete unit->inputs;
        delete unit->outputs;
        delete unit;
    }
}
// Breadth-first execution of the graph, driven by a work queue.
// A unit "computes" (prints its name) once every producer feeding it has
// computed. Each producer appends itself to the consumer's inputs list, and
// the consumer is enqueued once that list holds inputSize entries.
// NOTE: the inputs lists are appended to and never cleared, so a graph can
// only be scanned once.
void broadScanGraph(PUnit input)
{
    queue<PUnit> ready;
    ready.push(input);
    while (!ready.empty()) {
        PUnit unit = ready.front();
        ready.pop();
        if (unit == nullptr) {
            continue;
        }
        unit->visit = true;
        cout << unit->name << endl;
        if (unit->outputSize == 0) {
            continue;  // sink unit: nothing downstream
        }
        for (PUnit consumer : *unit->outputs) {
            consumer->inputs->push_back(unit);
            const bool allInputsArrived =
                consumer->inputs->size() == static_cast<size_t>(consumer->inputSize);
            if (allInputsArrived && !consumer->visit) {
                ready.push(consumer);
            }
        }
    }
}
//深度優先遍歷圖,釋放資源
void deepScanDestoryGraph(PUnit input)
{
    if(!input->outputSize)return;
    for(int i=0;i<input->outputs->size();i++){
        PUnit p = input->outputs->at(i);
        p->inputSize--;
        if(!p->inputSize){
            deepScanDestoryGraph(p);
            cout<<"destory: "<<p->name<<endl;
            if(p->inputs){
                p->inputs->clear();
                delete p->inputs;
            }
            if(p->outputs){
                p->outputs->clear();
                delete p->outputs;
            }
            if(p)delete p;
            p = nullptr;
        }
    }
}
//深度優先遍歷圖,執行計算
//採用遞歸
void deepScanGraph(PUnit input)
{
    if(!input->outputSize)return;
    for(int i=0;i<input->outputs->size();i++){
        PUnit p = input->outputs->at(i);
        p->inputs->push_back(input);
        if(p->inputSize == p->inputs->size() && !p->visit){
            cout<<p->name<<endl;
            p->visit = true;
            deepScanGraph(p);
        }
    }
}

int main(int argc, const char * argv[]) {
    // insert code here...
    cout<<"-------broad scan grapp--------"<<endl;
    PUnit input = createGraph();
    broadScanGraph(input);
    broadScanDestroyGraph(input);
    
    cout<<"-------deep scan grapp--------"<<endl;
    input = createGraph();
    deepScanGraph(input);
    deepScanDestoryGraph(input);
    return 0;
}

結果如下:

-------broad scan grapp--------
intput
first
second
third
forth
fifth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
-------deep scan grapp--------
first
forth
second
third
fifth
destory: forth
destory: first
destory: second
destory: fifth
destory: third
Program ended with exit code: 0

再增加一個圖元,構造一個更複雜的圖,看看結果是否正確:
(圖2:在圖1的基礎上新增 sixth;fifth 和 forth 的輸出都連到 sixth。)
構造代碼如下:

// Builds the extended sample graph and returns its input unit:
//   input  -> first, second, third
//   first  -> forth
//   second, third, forth -> fifth
//   fifth, forth -> sixth
// All units are allocated with `new`; the caller owns the returned graph.
// Each unit's inputSize/outputSize must equal the number of edges pushed
// below. forth feeds two consumers (fifth and sixth), so its outputSize is 2.
// It was wrongly declared as 1.
PUnit createGraph2()
{
    // Allocates one unit. An adjacency list is created only for a side that
    // has at least one edge; the other side is left as nullptr.
    auto makeUnit = [](const string &label, int fanIn, int fanOut) {
        PUnit unit = new Unit;
        unit->name = label;
        unit->inputSize = fanIn;
        unit->outputSize = fanOut;
        unit->visit = false;
        unit->inputs = fanIn ? new vector<PUnit> : nullptr;
        unit->outputs = fanOut ? new vector<PUnit> : nullptr;
        return unit;
    };

    PUnit input  = makeUnit("intput", 0, 3);
    PUnit first  = makeUnit("first", 1, 1);
    PUnit second = makeUnit("second", 1, 1);
    PUnit third  = makeUnit("third", 1, 1);
    PUnit forth  = makeUnit("forth", 1, 2);   // feeds fifth and sixth
    PUnit fifth  = makeUnit("fifth", 3, 1);
    PUnit sixth  = makeUnit("sixth", 2, 0);

    input->outputs->push_back(first);
    input->outputs->push_back(second);
    input->outputs->push_back(third);

    first->outputs->push_back(forth);

    second->outputs->push_back(fifth);
    third->outputs->push_back(fifth);
    forth->outputs->push_back(fifth);

    fifth->outputs->push_back(sixth);
    forth->outputs->push_back(sixth);

    return input;
}

執行結果如下:

-------broad scan grapp--------
intput
first
second
third
forth
fifth
sixth
destory: intput
destory: first
destory: second
destory: third
destory: forth
destory: fifth
destory: sixth
-------deep scan grapp--------
first
forth
second
third
fifth
sixth
destory: forth
destory: first
destory: second
destory: sixth
destory: fifth
destory: third
Program ended with exit code: 0

可見還是按預期運行的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章