在Jupyter笔记本中绘制交互式决策树

r0f*_*0f1 37 python machine-learning decision-tree scikit-learn jupyter

有没有办法在Jupyter笔记本中绘制决策树,以便我可以交互式地探索它的节点?我在考虑这样的事情DT.这是KNIME的一个例子.

我找到了https://planspace.org/20151129-see_sklearn_trees_with_d3/https://bl.ocks.org/ajschumacher/65eda1df2b0dd2cf616f,我知道你可以在Jupyter中运行d3,但是我没有找到任何包,那样做.

Moh*_*hif 14

使用Jupyter Notebook中的d3js更新了可折叠图形的答案

在笔记本电脑中启动第一个单元

%%html
<div id="d3-example"></div>
<style>

.node circle {
  cursor: pointer;
  stroke: #3182bd;
  stroke-width: 1.5px;
}

.node text {
  font: 10px sans-serif;
  pointer-events: none;
  text-anchor: middle;
}

line.link {
  fill: none;
  stroke: #9ecae1;
  stroke-width: 1.5px;
}
</style>
Run Code Online (Sandbox Code Playgroud)

笔记本电脑中的第一个电池结束

在笔记本中开始第二个单元格

%%javascript
// We load the d3.js library from the Web.
require.config({paths:
    {d3: "http://d3js.org/d3.v3.min"}});
require(["d3"], function(d3) {
  // The code in this block is executed when the
  // d3.js library has been loaded.

  // First, we specify the size of the canvas
  // containing the visualization (size of the
  // <div> element).
  var width = 960,
    height = 500,
    root;

  // We create a color scale.
  var color = d3.scale.category10();

  // We create a force-directed dynamic graph layout.
//   var force = d3.layout.force()
//     .charge(-120)
//     .linkDistance(30)
//     .size([width, height]);
    var force = d3.layout.force()
    .linkDistance(80)
    .charge(-120)
    .gravity(.05)
    .size([width, height])
    .on("tick", tick);
var svg = d3.select("body").append("svg")
    .attr("width", width)
    .attr("height", height);

var link = svg.selectAll(".link"),
    node = svg.selectAll(".node");

  // In the <div> element, we create a <svg> graphic
  // that will contain our interactive visualization.
 var svg = d3.select("#d3-example").select("svg")
  if (svg.empty()) {
    svg = d3.select("#d3-example").append("svg")
          .attr("width", width)
          .attr("height", height);
  }
var link = svg.selectAll(".link"),
    node = svg.selectAll(".node");
  // We load the JSON file.
  d3.json("graph2.json", function(error, json) {
    // In this block, the file has been loaded
    // and the 'graph' object contains our graph.
 if (error) throw error;
else
    test(1);
root = json;
      test(2);
      console.log(root);
  update();



  });
    function test(rr){console.log('yolo'+String(rr));}

function update() {
    test(3);
  var nodes = flatten(root),
      links = d3.layout.tree().links(nodes);

  // Restart the force layout.
  force
      .nodes(nodes)
      .links(links)
      .start();

  // Update links.
  link = link.data(links, function(d) { return d.target.id; });

  link.exit().remove();

  link.enter().insert("line", ".node")
      .attr("class", "link");

  // Update nodes.
  node = node.data(nodes, function(d) { return d.id; });

  node.exit().remove();

  var nodeEnter = node.enter().append("g")
      .attr("class", "node")
      .on("click", click)
      .call(force.drag);

  nodeEnter.append("circle")
      .attr("r", function(d) { return Math.sqrt(d.size) / 10 || 4.5; });

  nodeEnter.append("text")
      .attr("dy", ".35em")
      .text(function(d) { return d.name; });

  node.select("circle")
      .style("fill", color);
}
    function tick() {
  link.attr("x1", function(d) { return d.source.x; })
      .attr("y1", function(d) { return d.source.y; })
      .attr("x2", function(d) { return d.target.x; })
      .attr("y2", function(d) { return d.target.y; });

  node.attr("transform", function(d) { return "translate(" + d.x + "," + d.y + ")"; });
}
          function color(d) {
  return d._children ? "#3182bd" // collapsed package
      : d.children ? "#c6dbef" // expanded package
      : "#fd8d3c"; // leaf node
}
      // Toggle children on click.
function click(d) {
  if (d3.event.defaultPrevented) return; // ignore drag
  if (d.children) {
    d._children = d.children;
    d.children = null;
  } else {
    d.children = d._children;
    d._children = null;
  }
  update();
}
    function flatten(root) {
  var nodes = [], i = 0;

  function recurse(node) {
    if (node.children) node.children.forEach(recurse);
    if (!node.id) node.id = ++i;
    nodes.push(node);
  }

  recurse(root);
  return nodes;
}

});
Run Code Online (Sandbox Code Playgroud)

笔记本电脑中的第二个电池结束

graph2.json的内容

   {
 "name": "flare",
 "children": [
  {
   "name": "analytics"
    },
    {
   "name": "graph"
    }
   ]
}
Run Code Online (Sandbox Code Playgroud)

图表 在此输入图像描述

单击flare,这是根节点,其他节点将崩溃

在此输入图像描述

这里使用的笔记本的Github存储库:ipython笔记本中的可折叠树

参考

老答案

我在这里找到了这个教程,用于Jupyter Notebook中Decision Tree的交互式可视化.

安装graphviz

这有两个步骤:第1步:使用pip为python安装graphviz

pip install graphviz
Run Code Online (Sandbox Code Playgroud)

第2步:然后你必须单独安装graphviz.检查此链接.然后根据您的系统操作系统,您需要相应地设置路径:

对于Windows和Mac OS,请检查此链接.对于Linux/Ubuntu,请检查此链接

安装ipywidgets

用pip

pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension
Run Code Online (Sandbox Code Playgroud)

使用conda

conda install -c conda-forge ipywidgets
Run Code Online (Sandbox Code Playgroud)

现在为代码

from IPython.display import SVG
from graphviz import Source
from sklearn.datasets load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn import tree
from ipywidgets import interactive
from IPython.display import display                               
Run Code Online (Sandbox Code Playgroud)

加载数据集,例如在这种情况下称为iris数据集

data = load_iris()

#Get the feature matrix
features = data.data

#Get the labels for the sampels
target_label = data.target

#Get feature names
feature_names = data.feature_names
Run Code Online (Sandbox Code Playgroud)

**绘制决策树的函数**

def plot_tree(crit, split, depth, min_split, min_leaf=0.17):
    classifier = DecisionTreeClassifier(random_state = 123, criterion = crit, splitter = split, max_depth = depth, min_samples_split=min_split, min_samples_leaf=min_leaf)
    classifier.fit(features, target_label)

    graph = Source(tree.export_graphviz(classifier, out_file=None, feature_names=feature_names, class_names=['0', '1', '2'], filled = True))

    display(SVG(graph.pipe(format='svg')))
return classifier
Run Code Online (Sandbox Code Playgroud)

调用该函数

decision_plot = interactive(plot_tree, crit = ["gini", "entropy"], split = ["best", "random"]  , depth=[1, 2, 3, 4, 5, 6, 7], min_split=(0.1,1), min_leaf=(0.1,0.2,0.3,0.5))

display(decision_plot)
Run Code Online (Sandbox Code Playgroud)

您将获得以下图表 在此输入图像描述

您可以通过修改以下值在输出单元格中以交互方式更改参数

在此输入图像描述

关于相同数据但不同参数的另一个决策树 在此输入图像描述

参考文献:

  • 不幸的是,这不是我想要的答案。您描述了如何使用不同的输入参数构建不同的决策树。我对探索单一决策树很感兴趣。也就是说,交互式地折叠和扩展决策树节点以了解其所做的预测。此外,我的决策树可能有非常大(10-100 个)的节点。 (2认同)

Ank*_*hta 6

1.如果你只是想在Jupyter中使用D3,这是一个教程:https://medium.com/@stallonejacob/d3-in-juypter-notebook-685d6dca75c8

在此输入图像描述

在此输入图像描述

2.为了构建交互式决策树,这是另一个有趣的GUI工具包,称为TMVAGui.

在这里,代码只是一行代码: factory.DrawDecisionTree(dataset, "BDT")

https://indico.cern.ch/event/572131/contributions/2315243/attachments/1343269/2023816/gsoc16_4thpresentation.pdf


kam*_*kam 1

有一个名为 pydot 的模块。您可以创建图形并添加边来制作决策树。

import pydot # 

graph = pydot.Dot(graph_type='graph')
edge1 = pydot.Edge('1', '2', label = 'edge1')
edge2 = pydot.Edge('1', '3', label = 'edge2')
graph.add_edge(edge1)
graph.add_edge(edge2)

graph.write_png('my_graph.png')
Run Code Online (Sandbox Code Playgroud)

这是一个输出决策树的 png 文件的示例。希望这可以帮助!