# 參考資料:
# 《機器學習實戰》(Machine Learning in Action)
'''
@version: 0.0.1
@Author: tqrs
@dev: python3 vscode
@Date: 2019-11-11 11:24:22
@LastEditTime: 2019-11-11 16:56:01
@FilePath: \\機器學習實戰\\11-Apriori\\apriori.py
@Description: 如果一個元素項是不頻繁的,那麼所有包含該元素項的超集也是不頻繁的。Apriori算法從單元素項集開始,通過組合滿足最小支持度要求的項集來形成更大的集合。支持度用來度量一個集合在原始數據中出現的頻率。
'''
def loadDataSet():
    """
    [summary]: Return the small hard-coded transaction list used by the demos.
    Returns:
        List[List[int]] -- four transactions, each a list of integer items;
                           a new list is built on every call
    """
    transactions = [
        [1, 3, 4],
        [2, 3, 5],
        [1, 2, 3, 5],
        [2, 5],
    ]
    return transactions
def createC1(dataSet):
    """
    [summary]: Build C1, the list of all candidate itemsets of size one.
    Arguments:
        dataSet List[List] -- transactions; each transaction is an iterable
                              of hashable, mutually orderable items
    Returns:
        List[frozenset] -- one single-item frozenset per distinct item,
                           in ascending item order
    """
    # A set removes duplicates in a single pass; the previous approach
    # scanned a growing list for every item of every transaction.
    uniqueItems = {item for transaction in dataSet for item in transaction}
    # frozenset (not set) so that each candidate is hashable and can be
    # used as a dictionary key.
    return [frozenset([item]) for item in sorted(uniqueItems)]
def scanD(Data, Ck, minSupport):
    """
    [summary]: Count how often each candidate itemset occurs in the data and
               keep the candidates that meet the minimum support.
    Arguments:
        Data List[set] -- transactions, each given as a set of items
        Ck List[frozenset] -- candidate itemsets to count
        minSupport float -- minimum support, as a fraction of len(Data)
    Returns:
        retList -- candidates with support >= minSupport, in reverse order
                   of the first time each one was counted
        supportData -- dict mapping every candidate that occurs in at least
                       one transaction to its support
    """
    # Count each distinct candidate only once per transaction. A candidate
    # repeated in Ck would otherwise be counted several times and could end
    # up with a support greater than 1. dict.fromkeys keeps first-seen order.
    candidates = list(dict.fromkeys(Ck))

    ssCnt = {}
    for tid in Data:
        for can in candidates:
            if can.issubset(tid):
                ssCnt[can] = ssCnt.get(can, 0) + 1

    # ssCnt is empty whenever Data is empty, so the division below is never
    # reached with numItems == 0.
    numItems = len(Data)
    retList = []
    supportData = {}
    for key, count in ssCnt.items():
        support = count / numItems
        supportData[key] = support
        if support >= minSupport:
            retList.append(key)
    # One reverse at the end gives the same order as inserting every hit at
    # the front of the list, without shifting the list on each insert.
    retList.reverse()
    return retList, supportData
def test_scanD():
    """
    [summary]: Demo: print the frequent 1-itemsets and the support table of
               the sample data set at a minimum support of 0.5.
    """
    transactions = loadDataSet()
    candidates = createC1(transactions)
    transactionSets = [set(row) for row in transactions]
    frequentItems, supports = scanD(transactionSets, candidates, 0.5)
    print("L1:", frequentItems)
    print("suppData0:", supports)
def aprioriGen(Lk, k):
    """
    [summary]: Join step: merge pairs of itemsets from Lk into candidate
               itemsets of size k.
    Two itemsets are merged when their first k - 2 items, taken in sorted
    order, are identical.
    Arguments:
        Lk List[frozenset] -- itemsets to join (normally all of size k - 1);
                              the items must be mutually orderable
        k int -- size of the itemsets to generate
    Returns:
        retList -- list with the union of every matching pair
    """
    # Sort BEFORE slicing: iterating a frozenset yields items in hash order,
    # so slicing first would compare arbitrary items instead of the smallest
    # k - 2 and could miss candidates or emit the same candidate twice.
    # Each prefix is computed once here rather than once per pair.
    prefixes = [sorted(itemset)[:k - 2] for itemset in Lk]

    retList = []
    lenLk = len(Lk)
    for i in range(lenLk):
        for j in range(i + 1, lenLk):
            if prefixes[i] == prefixes[j]:
                retList.append(Lk[i] | Lk[j])
    return retList
def apriori(dataSet, minSupport=0.5):
    """
    [summary]: Find the frequent itemsets of dataSet level by level.
    Starting from the frequent 1-itemsets, repeatedly:
        build the size-k candidates from the latest frequent itemsets,
        scan the data to keep only the frequent ones,
        store them and move on to size k + 1,
    until a level comes back empty.
    Arguments:
        dataSet -- list of transactions (each an iterable of items)
    Keyword Arguments:
        minSupport {float} -- minimum support (default: {0.5})
    Returns:
        L -- list of levels; L[i] holds the frequent itemsets of size i + 1,
             and the last level is always an empty list
        supportData -- dict with the support of every itemset that was
                       counted in at least one transaction
    """
    singleCandidates = createC1(dataSet)
    transactions = [set(row) for row in dataSet]
    firstLevel, supportData = scanD(transactions, singleCandidates, minSupport)

    L = [firstLevel]
    size = 2
    # The most recently appended level is always L[size - 2].
    while L[-1]:
        candidates = aprioriGen(L[-1], size)
        frequent, levelSupport = scanD(transactions, candidates, minSupport)
        supportData.update(levelSupport)
        L.append(frequent)
        size += 1
    return L, supportData
def test_apriori():
    """
    [summary]: Demo: run apriori on the sample data set with the default
               minimum support and print both results.
    """
    frequentLevels, supports = apriori(loadDataSet())
    print("L:", frequentLevels)
    print("suppData:", supports)
def calcConf(freqSet, H, supportData, brl, minConf=0.7):
    """
    [summary]: Compute the confidence of the rule (freqSet - conseq) -> conseq
               for every conseq in H, where
               confidence(a -> b) = support(a | b) / support(a).
    Every rule that reaches minConf is printed and appended to brl.
    Arguments:
        freqSet -- a frequent itemset, e.g. frozenset([2, 3, 5])
        H -- candidate consequents, e.g.
             [frozenset([2]), frozenset([3]), frozenset([5])]
        supportData {dict} -- support of freqSet and of every antecedent
        brl {list} -- rule list, extended in place with
                      (antecedent, consequent, confidence) tuples
    Keyword Arguments:
        minConf {float} -- minimum confidence (default: {0.7})
    Returns:
        List -- the consequents from H whose rule reached minConf
    """
    keptConsequents = []
    for conseq in H:
        antecedent = freqSet - conseq
        conf = supportData[freqSet] / supportData[antecedent]
        if conf < minConf:
            continue
        print(antecedent, '-->', conseq, 'conf:', conf)
        brl.append((antecedent, conseq, conf))
        keptConsequents.append(conseq)
    return keptConsequents
def rulesFromConseq(freqSet, H, supportData, brl, minConf=0.7):
    """
    [summary]: Recursively generate rules for freqSet whose consequents are
               one item larger than those in H, and larger still while
               enough consequents survive.
    The consequents in H themselves are NOT evaluated here; only the merged
    consequents of size len(H[0]) + 1 and above are.
    Arguments:
        freqSet -- a frequent itemset, e.g. frozenset([2, 3, 5])
        H {list} -- consequents of equal size, e.g.
                    [frozenset([2]), frozenset([3]), frozenset([5])]
        supportData -- support dictionary
        brl {list} -- rule list, extended in place by calcConf
    Keyword Arguments:
        minConf {float} -- minimum confidence (default: {0.7})
    """
    # Nothing to merge; also avoids an IndexError on H[0] below.
    if not H:
        return
    m = len(H[0])
    # A consequent of size m + 1 must leave at least one item on the
    # left-hand side of the rule.
    if len(freqSet) > m + 1:
        Hmp1 = aprioriGen(H, m + 1)
        Hmp1 = calcConf(freqSet, Hmp1, supportData, brl, minConf)
        # At least two surviving consequents are needed to merge again.
        if len(Hmp1) > 1:
            rulesFromConseq(freqSet, Hmp1, supportData, brl, minConf)
def generateRules(L, supportData, minConf=0.7):
    """
    [summary]: Generate association rules from the frequent itemsets.
    Arguments:
        L {list} -- frequent itemsets grouped by level, as returned by
                    apriori (L[0] holds the single-item sets, which cannot
                    form rules and are skipped)
        supportData {dict} -- support of the frequent itemsets
    Keyword Arguments:
        minConf {float} -- minimum confidence (default: {0.7})
    Returns:
        [List] -- (antecedent, consequent, confidence) tuples for every
                  rule that reaches minConf
    """
    bigRuleList = []
    for i in range(1, len(L)):
        for freqSet in L[i]:
            H1 = [frozenset([item]) for item in freqSet]
            # Evaluate the single-item consequents for EVERY itemset.
            # Previously itemsets with three or more items skipped this
            # step, so rules such as {2, 3} -> {5} were never produced.
            H1 = calcConf(freqSet, H1, supportData, bigRuleList, minConf)
            # Larger consequents are built only from consequents that
            # passed: enlarging the consequent can only lower confidence.
            # At least two survivors are needed to merge.
            if i > 1 and len(H1) > 1:
                rulesFromConseq(freqSet, H1, supportData, bigRuleList, minConf)
    return bigRuleList
def test_rules():
    """
    [summary]: Demo: mine the sample data set at minimum support 0.5, then
               generate rules at minimum confidence 0.7, printing each result.
    """
    frequentLevels, supports = apriori(loadDataSet(), minSupport=0.5)
    print("L:", frequentLevels)
    print("suppData:", supports)
    foundRules = generateRules(frequentLevels, supports, minConf=0.7)
    print("rules:", foundRules)
def test_mushroom():
    """
    [summary]: Demo: mine the mushroom data file at minimum support 0.3 and
               print every frequent itemset in L[3] that contains the
               item '2'.
    The file is read as whitespace-separated tokens, one transaction per line.
    """
    # The with-block closes the file once it has been read; the handle was
    # previously left open.
    with open(r'./11-Apriori/mushroom.dat') as dataFile:
        mushDatSet = [line.split() for line in dataFile]
    L, suppData = apriori(mushDatSet, minSupport=0.3)
    for item in L[3]:
        # intersection() iterates the string '2' character by character,
        # so this checks whether the itemset contains the token '2'.
        if item.intersection('2'):
            print(item)
if __name__ == "__main__":
    # Running the file as a script executes the mushroom data set demo.
    test_mushroom()