指点成金-最美分享吧

登录

关于决策树可视化的treePlotter(学习笔记)

佚名 举报

篇首语:本文由小编为大家整理,主要介绍了关于决策树可视化的treePlotter(学习笔记)相关的知识,希望对你有一定的参考价值。

网上的版本好像好久都没更新了treePlotter是没有人用了么。今天学习的时候发现有些地方已经改了,我改的是在python 3.6 上的运行版本,需要导入matplotlib.pyplot

import matplotlib.pyplot as plt# 定义决策树决策结果属性descisionNode = dict(box, fc="0.8")leafNode = dict(box, fc="0.8")arrow_args = dict(arrow)
# myTree = {"no surfacing": {0: "no", 1: {"flippers": {0: "no", 1: "yes"}}}}def plotNode(nodeTxt, centerPt, parentPt, nodeType): # nodeTxt为要显示的文本,centerNode为文本中心点, nodeType为箭头所在的点, parentPt为指向文本的点 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction", xytext=centerPt, textcoords="axes fraction", va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)# def createPlot():# fig = plt.figure(1, facecolor="white")# fig.clf()# # createPlot.ax1为全局变量,绘制图像句柄# # frameon表示是否绘制坐标轴矩形# createPlot.ax1 = plt.subplot(111, frameon=False)# plotNode("a decision node", (0.5, 0.1), (0.1, 0.5), descisionNode)# plotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)# plt.show()# 这个是用来测试的# -----------分割线-------------# 获取树的叶子数和树的深度def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == "dict": numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafsdef getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] # 这个是改的地方,原来myTree.keys()返回的是dict_keys类,不是列表,运行会报错。有好几个地方这样 secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == "dict": thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth# ---------分割线-------------# 制图def createPlot(inTree): fig = plt.figure(1, facecolor="white") fig.clf() axprops = {"xticks": None, "yticks": None} createPlot.ax1 = plt.subplot(111, frameon=False) plotTree.totalW = float(getNumLeafs(inTree)) # 全局变量宽度 = 叶子数目 plotTree.totalD = float(getTreeDepth(inTree)) # 全局变量高度 = 深度 plotTree.xOff = -0.5/plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), "") plt.show()def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] # cntrPt文本中心点, parentPt指向文本中心的点 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, descisionNode) seconDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in seconDict.keys(): if type(seconDict[key]).__name__ == "dict": plotTree(seconDict[key], cntrPt, str(key)) else: plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)# createPlot(myTree)

  这个treePlotter导入了就可以把原来得到的决策树模型导入啦,而且要注意是以字典形式导入,所以保存和导入文件的时候最好用json。

发布5分钟之后,突然发现已经有人改过了,那就只算是个学习笔记吧 -  -

 

以上是关于关于决策树可视化的treePlotter(学习笔记)的主要内容,如果未能解决你的问题,请参考以下文章