R如何使用插入符包可视化混淆矩阵

shi*_*ish 8 r confusion-matrix

我想把我放在混淆矩阵中的数据可视化.有没有一个函数我可以简单地把混淆矩阵和它可视化它(绘制它)?

我想做的例子(Matrix $ nnet只是一个包含分类结果的表):

Confusion$nnet <- confusionMatrix(Matrix$nnet)
plot(Confusion$nnet)
Run Code Online (Sandbox Code Playgroud)

我的混乱$ nnet $表看起来像这样:

    prediction (I would also like to get rid of this string, any help?)
    1  2
1   42 6
2   8 28
Run Code Online (Sandbox Code Playgroud)

Rob*_*ski 18

你可以使用内置的fourfoldplot.例如,

ctable <- as.table(matrix(c(42, 6, 8, 28), nrow = 2, byrow = TRUE))
fourfoldplot(ctable, color = c("#CC6666", "#99CC99"),
             conf.level = 0, margin = 1, main = "Confusion Matrix")
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

  • 对混淆矩阵使用四重图是一个坏主意,因为这种图是根据行和列边际总数加权的。你能看到你对角的数量是 42 和 28,但大小/面积无法区分吗?四重图通常用于分析优势比,无论独立频率是什么,默认加权都有助于实现这一点。如果将其用于二元混淆矩阵,则可能完全具有误导性。您可能会错过一个事实,即您的 FP 或 FN 率很差。你可以通过设置 std = "all.max" 来解决这个问题 (3认同)

RLa*_*ave 17

您可以使用函数conf_mat()yardstickautoplot()在几排得到一个相当不错的结果。

另外,您仍然可以使用基本的ggplot语法来修复样式。

library(yardstick)
library(ggplot2)


# The confusion matrix from a single assessment set (i.e. fold)
cm <- conf_mat(truth_predicted, obs, pred)

autoplot(cm, type = "heatmap") +
  scale_fill_gradient(low="#D6EAF8",high = "#2E86C1")
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明


作为进一步自定义的示例,使用ggplotsintax 您还可以添加回图例:

+ theme(legend.position = "right")
Run Code Online (Sandbox Code Playgroud)

更改图例的名称也很容易: + labs(fill="legend_name")

在此处输入图片说明

数据示例:

set.seed(123)
truth_predicted <- data.frame(
  obs = sample(0:1,100, replace = T),
  pred = sample(0:1,100, replace = T)
)
truth_predicted$obs <- as.factor(truth_predicted$obs)
truth_predicted$pred <- as.factor(truth_predicted$pred)
Run Code Online (Sandbox Code Playgroud)


小智 14

我真的很喜欢@Cyber​​netic 漂亮的混淆矩阵可视化,并做了两个调整以希望进一步改进它。

1) 我用类的实际值交换了 Class1 和 Class2。2)我用基于百分位数生成红色(未命中)和绿色(命中)的函数替换橙色和蓝色。这个想法是快速查看问题/成功的位置及其大小。

截图和代码:

混淆矩阵代码更新

draw_confusion_matrix <- function(cm) {

  total <- sum(cm$table)
  res <- as.numeric(cm$table)

  # Generate color gradients. Palettes come from RColorBrewer.
  greenPalette <- c("#F7FCF5","#E5F5E0","#C7E9C0","#A1D99B","#74C476","#41AB5D","#238B45","#006D2C","#00441B")
  redPalette <- c("#FFF5F0","#FEE0D2","#FCBBA1","#FC9272","#FB6A4A","#EF3B2C","#CB181D","#A50F15","#67000D")
  getColor <- function (greenOrRed = "green", amount = 0) {
    if (amount == 0)
      return("#FFFFFF")
    palette <- greenPalette
    if (greenOrRed == "red")
      palette <- redPalette
    colorRampPalette(palette)(100)[10 + ceiling(90 * amount / total)]
  }

  # set the basic layout
  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  classes = colnames(cm$table)
  rect(150, 430, 240, 370, col=getColor("green", res[1]))
  text(195, 435, classes[1], cex=1.2)
  rect(250, 430, 340, 370, col=getColor("red", res[3]))
  text(295, 435, classes[2], cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col=getColor("red", res[2]))
  rect(250, 305, 340, 365, col=getColor("green", res[4]))
  text(140, 400, classes[1], cex=1.2, srt=90)
  text(140, 335, classes[2], cex=1.2, srt=90)

  # add in the cm results
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}
Run Code Online (Sandbox Code Playgroud)


Cyb*_*tic 13

您可以使用r中的rect功能来布局混淆矩阵.在这里,我们将创建一个函数,允许用户传入由插入符包创建的cm对象,以产生视觉效果.

让我们首先创建一个在插入符演示中完成的评估数据集:

# construct the evaluation dataset
set.seed(144)
true_class <- factor(sample(paste0("Class", 1:2), size = 1000, prob = c(.2, .8), replace = TRUE))
true_class <- sort(true_class)
class1_probs <- rbeta(sum(true_class == "Class1"), 4, 1)
class2_probs <- rbeta(sum(true_class == "Class2"), 1, 2.5)
test_set <- data.frame(obs = true_class,Class1 = c(class1_probs, class2_probs))
test_set$Class2 <- 1 - test_set$Class1
test_set$pred <- factor(ifelse(test_set$Class1 >= .5, "Class1", "Class2"))
Run Code Online (Sandbox Code Playgroud)

现在让我们使用插入符来计算混淆矩阵:

# calculate the confusion matrix
cm <- confusionMatrix(data = test_set$pred, reference = test_set$obs)
Run Code Online (Sandbox Code Playgroud)

现在我们创建一个函数,根据需要布置矩形,以更具视觉吸引力的方式展示混淆矩阵:

draw_confusion_matrix <- function(cm) {

  layout(matrix(c(1,1,2)))
  par(mar=c(2,2,2,2))
  plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
  title('CONFUSION MATRIX', cex.main=2)

  # create the matrix 
  rect(150, 430, 240, 370, col='#3F97D0')
  text(195, 435, 'Class1', cex=1.2)
  rect(250, 430, 340, 370, col='#F7AD50')
  text(295, 435, 'Class2', cex=1.2)
  text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
  text(245, 450, 'Actual', cex=1.3, font=2)
  rect(150, 305, 240, 365, col='#F7AD50')
  rect(250, 305, 340, 365, col='#3F97D0')
  text(140, 400, 'Class1', cex=1.2, srt=90)
  text(140, 335, 'Class2', cex=1.2, srt=90)

  # add in the cm results 
  res <- as.numeric(cm$table)
  text(195, 400, res[1], cex=1.6, font=2, col='white')
  text(195, 335, res[2], cex=1.6, font=2, col='white')
  text(295, 400, res[3], cex=1.6, font=2, col='white')
  text(295, 335, res[4], cex=1.6, font=2, col='white')

  # add in the specifics 
  plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
  text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
  text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
  text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
  text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
  text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
  text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
  text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
  text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
  text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
  text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

  # add in the accuracy information 
  text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
  text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
  text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
  text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
}  
Run Code Online (Sandbox Code Playgroud)

最后,传入我们在使用插入符来创建混淆矩阵时计算的cm对象:

draw_confusion_matrix(cm)
Run Code Online (Sandbox Code Playgroud)

以下是结果:

从插入符号包中可视化混淆矩阵


s_t*_*s_t 6

这是一个ggplot2可以根据需要更改的简单想法,我正在使用此链接中的数据:

#data
confusionMatrix(iris$Species, sample(iris$Species))
newPrior <- c(.05, .8, .15)
names(newPrior) <- levels(iris$Species)

cm <-confusionMatrix(iris$Species, sample(iris$Species))
Run Code Online (Sandbox Code Playgroud)

现在 cm 是一个混淆矩阵对象,可以取出一些对问题有用的东西:

# extract the confusion matrix values as data.frame
cm_d <- as.data.frame(cm$table)
# confusion matrix statistics as data.frame
cm_st <-data.frame(cm$overall)
# round the values
cm_st$cm.overall <- round(cm_st$cm.overall,2)

# here we also have the rounded percentage values
cm_p <- as.data.frame(prop.table(cm$table))
cm_d$Perc <- round(cm_p$Freq*100,2)
Run Code Online (Sandbox Code Playgroud)

现在我们准备好绘制:

library(ggplot2)     # to plot
library(gridExtra)   # to put more
library(grid)        # plot together

# plotting the matrix
cm_d_p <-  ggplot(data = cm_d, aes(x = Prediction , y =  Reference, fill = Freq))+
  geom_tile() +
  geom_text(aes(label = paste("",Freq,",",Perc,"%")), color = 'red', size = 8) +
  theme_light() +
  guides(fill=FALSE) 

# plotting the stats
cm_st_p <-  tableGrob(cm_st)

# all together
grid.arrange(cm_d_p, cm_st_p,nrow = 1, ncol = 2, 
             top=textGrob("Confusion Matrix and Statistics",gp=gpar(fontsize=25,font=1)))
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明