構建貝葉斯網絡——K2算法(Python實現)

 K2算法僞代碼如下:主要思想是貪婪搜索——對每個節點,反覆加入能使評分提高最多的那個父節點,直到評分不再提高或父節點數達到上限 u 爲止。由於候選父節點只從給定節點次序中排在該節點之前的節點裏選取,每次插入的邊都不會構成環。

from factor import fact
import math
import random

# First pass over the temporal edge list: collect the node set and the last
# timestamp. Each record is "<src> <dst> <timestamp>".
# (The original notes say the data file has 12216 records.)
with open('C:\\Users\\lenovo\\Desktop\\網絡科學導論cpp代碼\\時序網絡數據集\\email-Eu-core-temporal-Dept3.txt', 'r') as File:
    lineList = File.readlines()
print(len(lineList))
lasttime = -1
num = 0  # number of timestamp changes between consecutive records
NodeSet = set()
for line in lineList:
    # split() without an argument tolerates repeated blanks and the newline.
    curlineList = line.split()
    if not curlineList:
        continue  # skip blank lines, e.g. an empty trailing line
    ProNode = int(curlineList[0])
    PostNode = int(curlineList[1])
    # Ids below 20 are shifted up by one.
    # NOTE(review): this maps 19 to 20, which would collide with an original
    # node 20 if the dataset contains one -- confirm this remapping is intended.
    if ProNode < 20:
        ProNode += 1
    if PostNode < 20:
        PostNode += 1
    NodeSet.add(ProNode)
    NodeSet.add(PostNode)
    # int() ignores surrounding whitespace, so there is no need to slice off
    # the newline (slicing dropped a digit when the last line had none).
    curtime = int(curlineList[2])
    if lasttime == curtime:
        continue
    lasttime = curtime
    num += 1

print(lasttime)
# The original notes report 8911 distinct timestamps.

# Flatten the node pairs (89*89 in the original notes) into one dimension.
NodeList = sorted(NodeSet)
print(len(NodeList))
print(NodeList)
N = NodeList[-1]  # largest node id
print('N:'+str(N))

# Assign a column index to every undirected edge variable (i, j) with
# 1 <= i < j <= N. Slot 0 of Index is a placeholder so that edge columns
# start at 1; EdgeToIndex is the inverse mapping (edge tuple -> column).
Index = [0]
Index.extend((i, j) for i in range(1, N + 1) for j in range(i + 1, N + 1))
EdgeToIndex = {pair: col for col, pair in enumerate(Index) if col}

print(len(Index))
print(Index)
# The original notes report 3916 undirected edge variables.

# Slice the timeline into NUM_SLICES equal windows (the original intent was
# roughly half a day per window) and mark, per window, which undirected edges
# occurred: Data[t][col] == 1 iff edge variable col was active in window t.
NUM_SLICES = 803
# One column per entry of Index (slot 0 is the unused placeholder).
Data = [[0] * len(Index) for _ in range(NUM_SLICES)]

# Window length. Assumes the file is sorted by time, so lasttime is the
# largest timestamp. The original notes report 56426 for this dataset.
# max(..., 1) guards against a zero window when lasttime < NUM_SLICES.
T = max(lasttime // NUM_SLICES, 1)
print(T)

with open('C:\\Users\\lenovo\\Desktop\\網絡科學導論cpp代碼\\時序網絡數據集\\email-Eu-core-temporal-Dept3.txt', 'r') as File:
    lineList = File.readlines()
print(len(lineList))
num = 0  # number of records read
for line in lineList:
    curlineList = line.split()
    if not curlineList:
        continue  # skip blank lines
    ProNode = int(curlineList[0])
    PostNode = int(curlineList[1])
    # Same id shift as in the first pass over the file.
    if ProNode < 20:
        ProNode += 1
    if PostNode < 20:
        PostNode += 1
    # Undirected edge: order the endpoints so that PostNode <= ProNode.
    if PostNode > ProNode:
        ProNode, PostNode = PostNode, ProNode
    curtime = int(curlineList[2])
    num += 1
    if ProNode == PostNode or curtime < 0:
        # Self-loops have no edge variable; negative times fall in no window.
        continue
    # Window t covers [t*T, (t+1)*T). The former strict lower bound dropped
    # every timestamp that is an exact multiple of T, including time 0.
    # Timestamps past the last full window go into the final window instead
    # of being dropped.
    t = min(curtime // T, NUM_SLICES - 1)
    Data[t][EdgeToIndex[(PostNode, ProNode)]] = 1

print(Data[0])

# Inject random noise: in every window, set 500 randomly chosen edge
# variables to 1.
# NOTE(review): random is not seeded, so the result differs between runs --
# call random.seed(...) beforehand if reproducibility matters.
for row in Data:
    # Column 0 is the placeholder slot, so sample only the real edge columns
    # 1..len(row)-1. The former range(3916) could hit the placeholder and
    # never reached the last edge column.
    for col in random.sample(range(1, len(row)), 500):
        row[col] = 1

# Persist the matrix: one window per line, every value followed by '|'.
with open('C:\\Users\\lenovo\\Desktop\\網絡科學導論cpp代碼\\時序網絡數據集\\Data.txt', 'w') as NewFile:
    for row in Data:
        NewFile.write(''.join(str(value) + '|' for value in row) + '\n')

# The dataset Data is now complete.

# Next: learn a Bayesian belief network over the edge variables with K2.
u = 10  # upper bound on the number of parents per variable
# ParentSet[i] collects the parents chosen for edge variable i. It is sized
# like a Data row (len(Index)) so the two cannot drift apart; slot 0 is unused.
ParentSet = [[] for _ in range(len(Index))]


# Core step of K2: pick the single parent whose addition maximizes the score.
def maximize(ii, Parents, data=None, holdout=200):
    """Return the best additional parent for variable ``ii`` and its K2 score.

    Candidates are the columns ``zz`` with ``1 <= zz < ii`` (those preceding
    ``ii`` in the fixed variable ordering, which keeps the network acyclic)
    that are not already in ``Parents``. Each candidate is scored with the
    log form of the Cooper-Herskovits (K2) metric for a binary child::

        sum over parent configurations j of
            log N_ij0! + log N_ij1! - log (N_ij0 + N_ij1 + 1)!

    where the parents are ``Parents + [zz]`` and ``N_ijk`` counts training
    rows with parent configuration ``j`` and child value ``k``. The first
    candidate reaching the maximum score wins ties.

    Args:
        ii: column index of the child variable.
        Parents: column indices already chosen as parents of ``ii``.
        data: rows of 0/1 values indexed by column; defaults to the
            module-level ``Data``.
        holdout: number of trailing rows left out of scoring; the first
            ``len(data) - holdout`` rows are the training rows.

    Returns:
        ``(z, g)``: the best candidate column and its log score, or
        ``(-1, -math.inf)`` when no candidate is available.
    """
    if data is None:
        data = Data
    training_rows = data[:max(len(data) - holdout, 0)]
    already_chosen = set(Parents)

    best_score = -math.inf
    best_z = -1
    for zz in range(1, ii):
        if zz in already_chosen:
            continue
        columns = list(Parents) + [zz]
        # counts[config] = [N_ij0, N_ij1] for each parent configuration seen.
        # Unseen configurations contribute log 0! + log 0! - log 1! = 0, so a
        # single pass over the rows replaces enumerating all 2**len(columns)
        # configurations and rescanning the data for each one.
        counts = {}
        for row in training_rows:
            child = row[ii]
            if child not in (0, 1):
                continue  # the metric is defined for a binary child only
            config = tuple(row[c] for c in columns)
            counts.setdefault(config, [0, 0])[child] += 1

        # lgamma(n + 1) == log(n!) without building huge factorials.
        score = 0.0
        for n0, n1 in counts.values():
            score += (math.lgamma(n0 + 1) + math.lgamma(n1 + 1)
                      - math.lgamma(n0 + n1 + 2))

        if best_score < score:
            best_score = score
            best_z = zz
    return best_z, best_score


def _empty_parent_score(ii, data, holdout=200):
    """Return the K2 log score of column ``ii`` when it has no parents.

    Uses the same training rows as ``maximize`` (all but the last
    ``holdout`` rows): log N_0! + log N_1! - log (N_0 + N_1 + 1)!.
    """
    rows = data[:max(len(data) - holdout, 0)]
    n0 = sum(1 for row in rows if row[ii] == 0)
    n1 = sum(1 for row in rows if row[ii] == 1)
    return math.lgamma(n0 + 1) + math.lgamma(n1 + 1) - math.lgamma(n0 + n1 + 2)


# Run K2 for each edge variable and write every learned arc as "parent|child".
# NOTE(review): only variables 1..499 are processed (presumably to bound the
# running time); extend the range to len(ParentSet) for the full network.
with open('C:\\Users\\lenovo\\Desktop\\網絡科學導論cpp代碼\\時序網絡數據集\\network.txt', 'w') as NewFile:
    for i in range(1, 500):
        # K2 compares each candidate against the score of the current parent
        # set, which starts empty. Starting from a huge negative constant
        # instead made the first candidate parent always accepted.
        P_old = _empty_parent_score(i, Data)
        OKToProceed = True
        while OKToProceed and len(ParentSet[i]) < u:
            z, P_new = maximize(i, ParentSet[i])
            # z == -1 means no candidate was left to try.
            if z != -1 and P_new > P_old:
                P_old = P_new
                ParentSet[i].append(z)
            else:
                OKToProceed = False
        print("Node"+str(i)+":  Parents of this node:")
        print(ParentSet[i])
        for parent in ParentSet[i]:
            NewFile.write(str(parent)+'|'+str(i)+'\n')

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章