您好,登錄后才能下訂單哦!
本文實例講述了決策樹剪枝算法的python實現方法。分享給大家供大家參考,具體如下:
決策樹是一種依托決策而建立起來的一種樹。在機器學習中,決策樹是一種預測模型,代表的是一種對象屬性與對象值之間的一種映射關系,每一個節點代表某個對象,樹中的每一個分叉路徑代表某個可能的屬性值,而每一個葉子節點則對應從根節點到該葉子節點所經歷的路徑所表示的對象的值。決策樹僅有單一輸出,如果有多個輸出,可以分別建立獨立的決策樹以處理不同的輸出。
ID3算法:ID3算法是決策樹的一種,是基于奧卡姆剃刀原理的,即用盡量用較少的東西做更多的事。ID3算法,即Iterative Dichotomiser 3,迭代二叉樹3代,是Ross Quinlan發明的一種決策樹算法,這個算法的基礎就是上面提到的奧卡姆剃刀原理,越是小型的決策樹越優于大的決策樹,盡管如此,也不總是生成最小的樹型結構,而是一個啟發式算法。在信息論中,期望信息越小,那么信息增益就越大,從而純度就越高。ID3算法的核心思想就是以信息增益來度量屬性的選擇,選擇分裂后信息增益最大的屬性進行分裂。該算法采用自頂向下的貪婪搜索遍歷可能的決策空間。
信息熵,將其定義為離散隨機事件出現的概率,一個系統越是有序,信息熵就越低,反之一個系統越是混亂,它的信息熵就越高。所以信息熵可以被認為是系統有序化程度的一個度量。
基尼指數:在CART里面劃分決策樹的條件是采用Gini Index,定義如下:gini(T)=1−sumnj=1p2j。其中,( p_j )是類j在T中的相對頻率,當類在T中是傾斜的時,gini(T)會最小。將T劃分為T1(實例數為N1)和T2(實例數為N2)兩個子集后,劃分數據的Gini定義如下:ginisplit(T)=fracN1Ngini(T1)+fracN2Ngini(T2),然后選擇其中最小的(gini_{split}(T) )作為結點劃分決策樹
具體實現
首先用函數calcShanno計算數據集的香農熵,給所有可能的分類創建字典
def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} # 給所有可能分類創建字典 for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 # 以2為底數計算香農熵 for key in labelCounts: prob = float(labelCounts[key]) / numEntries shannonEnt -= prob * log(prob, 2) return shannonEnt
# 對離散變量劃分數據集,取出該特征取值為value的所有樣本 def splitDataSet(dataSet, axis, value): retDataSet = [] for featVec in dataSet: if featVec[axis] == value: reducedFeatVec = featVec[:axis] reducedFeatVec.extend(featVec[axis + 1:]) retDataSet.append(reducedFeatVec) return retDataSet
對連續變量劃分數據集,direction規定劃分的方向, 決定是劃分出小于value的數據樣本還是大于value的數據樣本集
numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 bestSplitDict = {} for i in range(numFeatures): featList = [example[i] for example in dataSet] # 對連續型特征進行處理 if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int': # 產生n-1個候選劃分點 sortfeatList = sorted(featList) splitList = [] for j in range(len(sortfeatList) - 1): splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0) bestSplitEntropy = 10000 slen = len(splitList) # 求用第j個候選劃分點劃分時,得到的信息熵,并記錄最佳劃分點 for j in range(slen): value = splitList[j] newEntropy = 0.0 subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0) subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1) prob0 = len(subDataSet0) / float(len(dataSet)) newEntropy += prob0 * calcShannonEnt(subDataSet0) prob1 = len(subDataSet1) / float(len(dataSet)) newEntropy += prob1 * calcShannonEnt(subDataSet1) if newEntropy < bestSplitEntropy: bestSplitEntropy = newEntropy bestSplit = j # 用字典記錄當前特征的最佳劃分點 bestSplitDict[labels[i]] = splitList[bestSplit] infoGain = baseEntropy - bestSplitEntropy # 對離散型特征進行處理 else: uniqueVals = set(featList) newEntropy = 0.0 # 計算該特征下每種劃分的信息熵 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i # 若當前節點的最佳劃分特征為連續特征,則將其以之前記錄的劃分點為界進行二值化處理 # 即是否小于等于bestSplitValue if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int': bestSplitValue = bestSplitDict[labels[bestFeature]] labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue) for i in range(shape(dataSet)[0]): if dataSet[i][bestFeature] <= bestSplitValue: dataSet[i][bestFeature] = 1 else: dataSet[i][bestFeature] = 0 return bestFeature
def chooseBestFeatureToSplit(dataSet, labels): numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0 bestFeature = -1 bestSplitDict = {} for i in range(numFeatures): featList = [example[i] for example in dataSet] # 對連續型特征進行處理 if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int': # 產生n-1個候選劃分點 sortfeatList = sorted(featList) splitList = [] for j in range(len(sortfeatList) - 1): splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0) bestSplitEntropy = 10000 slen = len(splitList) # 求用第j個候選劃分點劃分時,得到的信息熵,并記錄最佳劃分點 for j in range(slen): value = splitList[j] newEntropy = 0.0 subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0) subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1) prob0 = len(subDataSet0) / float(len(dataSet)) newEntropy += prob0 * calcShannonEnt(subDataSet0) prob1 = len(subDataSet1) / float(len(dataSet)) newEntropy += prob1 * calcShannonEnt(subDataSet1) if newEntropy < bestSplitEntropy: bestSplitEntropy = newEntropy bestSplit = j # 用字典記錄當前特征的最佳劃分點 bestSplitDict[labels[i]] = splitList[bestSplit] infoGain = baseEntropy - bestSplitEntropy # 對離散型特征進行處理 else: uniqueVals = set(featList) newEntropy = 0.0 # 計算該特征下每種劃分的信息熵 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i # 若當前節點的最佳劃分特征為連續特征,則將其以之前記錄的劃分點為界進行二值化處理 # 即是否小于等于bestSplitValue if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int': bestSplitValue = bestSplitDict[labels[bestFeature]] labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue) for i in range(shape(dataSet)[0]): if dataSet[i][bestFeature] <= bestSplitValue: dataSet[i][bestFeature] = 1 else: dataSet[i][bestFeature] = 0 return bestFeature ``def classify(inputTree, featLabels, testVec): firstStr = inputTree.keys()[0] if u'<=' in firstStr: featvalue = float(firstStr.split(u"<=")[1]) featkey = firstStr.split(u"<=")[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(featkey) if testVec[featIndex] <= featvalue: judge = 1 else: judge = 0 for key in secondDict.keys(): if judge == int(key): if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key], featLabels, testVec) else: classLabel = secondDict[key] else: secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) for key in secondDict.keys(): if testVec[featIndex] == key: if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key], featLabels, testVec) else: classLabel = secondDict[key] return classLabel
def majorityCnt(classList): classCount={} for vote in classList: if vote not in classCount.keys(): classCount[vote]=0 classCount[vote]+=1 return max(classCount) def testing_feat(feat, train_data, test_data, labels): class_list = [example[-1] for example in train_data] bestFeatIndex = labels.index(feat) train_data = [example[bestFeatIndex] for example in train_data] test_data = [(example[bestFeatIndex], example[-1]) for example in test_data] all_feat = set(train_data) error = 0.0 for value in all_feat: class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value] major = majorityCnt(class_feat) for data in test_data: if data[0] == value and data[1] != major: error += 1.0 # print 'myTree %d' % error return error
測試
error = 0.0 for i in range(len(data_test)): if classify(myTree, labels, data_test[i]) != data_test[i][-1]: error += 1 # print 'myTree %d' % error return float(error) def testingMajor(major, data_test): error = 0.0 for i in range(len(data_test)): if major != data_test[i][-1]: error += 1 # print 'major %d' % error return float(error) **遞歸產生決策樹** ```def createTree(dataSet,labels,data_full,labels_full,test_data,mode): classList=[example[-1] for example in dataSet] if classList.count(classList[0])==len(classList): return classList[0] if len(dataSet[0])==1: return majorityCnt(classList) labels_copy = copy.deepcopy(labels) bestFeat=chooseBestFeatureToSplit(dataSet,labels) bestFeatLabel=labels[bestFeat] if mode == "unpro" or mode == "post": myTree = {bestFeatLabel: {}} elif mode == "prev": if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList), test_data): myTree = {bestFeatLabel: {}} else: return majorityCnt(classList) featValues=[example[bestFeat] for example in dataSet] uniqueVals=set(featValues) if type(dataSet[0][bestFeat]).__name__ == 'unicode': currentlabel = labels_full.index(labels[bestFeat]) featValuesFull = [example[currentlabel] for example in data_full] uniqueValsFull = set(featValuesFull) del (labels[bestFeat]) for value in uniqueVals: subLabels = labels[:] if type(dataSet[0][bestFeat]).__name__ == 'unicode': uniqueValsFull.remove(value) myTree[bestFeatLabel][value] = createTree(splitDataSet \ (dataSet, bestFeat, value), subLabels, data_full, labels_full, splitDataSet \ (test_data, bestFeat, value), mode=mode) if type(dataSet[0][bestFeat]).__name__ == 'unicode': for value in uniqueValsFull: myTree[bestFeatLabel][value] = majorityCnt(classList) if mode == "post": if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data): return majorityCnt(classList) return myTree <div class="se-preview-section-delimiter"></div> ```**讀入數據** ```def load_data(file_name): with open(r"dd.csv", 'rb') as f: df = pd.read_csv(f,sep=",") print(df) train_data = df.values[:11, 1:].tolist() print(train_data) test_data = df.values[11:, 1:].tolist() labels = df.columns.values[1:-1].tolist() return train_data, test_data, labels <div class="se-preview-section-delimiter"></div> ```測試并繪制樹圖 import matplotlib.pyplot as plt decisionNode = dict(box, color='red') # 定義判斷結點形態 leafNode = dict(box, color='grey') # 定義葉結點形態 arrow_args = dict(arrow, color='blue') # 定義箭頭 # 計算樹的葉子節點數量 def getNumLeafs(myTree): numLeafs = 0 firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs # 計算樹的最大深度 def getTreeDepth(myTree): maxDepth = 0 firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth # 畫節點 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \ xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \ bbox=nodeType, arrowprops=arrow_args) # 畫箭頭上的文字 def plotMidText(cntrPt, parentPt, txtString): lens = len(txtString) xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002 yMid = (parentPt[1] + cntrPt[1]) / 2.0 createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstSides = list(myTree.keys()) firstStr = firstSides[0] cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key], cntrPt, str(key)) else: plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode) plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key)) plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.x0ff = -0.5 / plotTree.totalW plotTree.y0ff = 1.0 plotTree(inTree, (0.5, 1.0), '') plt.show()
if __name__ == "__main__": train_data, test_data, labels = load_data("dd.csv") data_full = train_data[:] labels_full = labels[:] mode="post" mode = "prev" mode="post" myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode) createPlot(myTree) print(json.dumps(myTree, ensure_ascii=False, indent=4))
選擇mode就可以分別得到三種樹圖
if __name__ == "__main__": train_data, test_data, labels = load_data("dd.csv") data_full = train_data[:] labels_full = labels[:] mode="post" mode = "prev" mode="post" myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode) createPlot(myTree) print(json.dumps(myTree, ensure_ascii=False, indent=4))
選擇mode就可以分別得到三種樹圖
更多關于Python相關內容感興趣的讀者可查看本站專題:《Python數據結構與算法教程》、《Python加密解密算法與技巧總結》、《Python編碼操作技巧總結》、《Python函數使用技巧總結》、《Python字符串操作技巧匯總》及《Python入門與進階經典教程》
希望本文所述對大家Python程序設計有所幫助。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。