如何在matplotlib中为条形图添加组标签?

Var*_*cus 28 python matplotlib bar-chart

我想使用matplotlib的条形图功能绘制以下形式的数据:

data = {'Room A':
           {'Shelf 1':
               {'Milk': 10,
                'Water': 20},
            'Shelf 2':
               {'Sugar': 5,
                'Honey': 6}
           },
        'Room B':
           {'Shelf 1':
               {'Wheat': 4,
                'Corn': 7},
            'Shelf 2':
               {'Chicken': 2,
                'Cow': 1}
           }
       }
Run Code Online (Sandbox Code Playgroud)

条形图应该看起来像这样.应从x轴上的标签可见条形图组.有没有办法用matplotlib做到这一点?

Var*_*cus 46

由于我在matplotlib中找不到内置的解决方案,我自己编写了代码:

#!/usr/bin/env python

from matplotlib import pyplot as plt

def mk_groups(data):
    try:
        newdata = data.items()
    except:
        return

    thisgroup = []
    groups = []
    for key, value in newdata:
        newgroups = mk_groups(value)
        if newgroups is None:
            thisgroup.append((key, value))
        else:
            thisgroup.append((key, len(newgroups[-1])))
            if groups:
                groups = [g + n for n, g in zip(newgroups, groups)]
            else:
                groups = newgroups
    return [thisgroup] + groups

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='black')
    line.set_clip_on(False)
    ax.add_line(line)

def label_group_bar(ax, data):
    groups = mk_groups(data)
    xy = groups.pop()
    x, y = zip(*xy)
    ly = len(y)
    xticks = range(1, ly + 1)

    ax.bar(xticks, y, align='center')
    ax.set_xticks(xticks)
    ax.set_xticklabels(x)
    ax.set_xlim(.5, ly + .5)
    ax.yaxis.grid(True)

    scale = 1. / ly
    for pos in xrange(ly + 1):
        add_line(ax, pos * scale, -.1)
    ypos = -.2
    while groups:
        group = groups.pop()
        pos = 0
        for label, rpos in group:
            lxpos = (pos + .5 * rpos) * scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos * scale, ypos)
            pos += rpos
        add_line(ax, pos * scale, ypos)
        ypos -= .1

if __name__ == '__main__':
    data = {'Room A':
               {'Shelf 1':
                   {'Milk': 10,
                    'Water': 20},
                'Shelf 2':
                   {'Sugar': 5,
                    'Honey': 6}
               },
            'Room B':
               {'Shelf 1':
                   {'Wheat': 4,
                    'Corn': 7},
                'Shelf 2':
                   {'Chicken': 2,
                    'Cow': 1}
               }
           }
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    label_group_bar(ax, data)
    fig.subplots_adjust(bottom=0.3)
    fig.savefig('label_group_bar_example.png')
Run Code Online (Sandbox Code Playgroud)

"mk_groups"函数接受一个字典(或任何带有items()方法的东西,如collections.OrderedDict),并将其转换为数据格式,然后用于创建图表.它基本上是一个表单列表:

[ [(label, bars_to_span), ...], ..., [(tick_label, bar_value), ...] ]
Run Code Online (Sandbox Code Playgroud)

"add_line"函数在指定位置(轴坐标)的子图中创建一条垂直线.

"label_group_bar"函数使用字典并在子图中创建条形图,标签位于下方.然后,示例的结果如下所示.

更加容易或更好的解决方案和建议仍然非常受欢迎.

条形图与组

  • 请接受你自己的答案.如果你感觉雄心勃勃,这可能是mpl适当的一个很好的补充(我怀疑这可能会被推广一点并折叠起来的树状图). (2认同)
  • 没问题.我很认真地把它变成mpl正确的(我是一个开发者,并且会帮助这一点通过). (2认同)
  • 如果您使用的是 Python 3,则 xrange 已重命名为 range。所以使用范围而不是 xrange。 (2认同)

小智 13

我一直在寻找这个解决方案.我修改了一些以使用pandas数据表.只有公平分享.

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby

def test_table():
    data_table = pd.DataFrame({'Room':['Room A']*4 + ['Room B']*4,
                               'Shelf':(['Shelf 1']*2 + ['Shelf 2']*2)*2,
                               'Staple':['Milk','Water','Sugar','Honey','Wheat','Corn','Chicken','Cow'],
                               'Quantity':[10,20,5,6,4,7,2,1],
                               'Ordered':np.random.randint(0,10,8)
                               })
    return data_table

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='black')
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]

def label_group_bar_table(ax, df):
    ypos = -.1
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos*scale, ypos)
            pos += rpos
        add_line(ax, pos*scale , ypos)
        ypos -= .1

df = test_table().groupby(['Room','Shelf','Staple']).sum()
fig = plt.figure()
ax = fig.add_subplot(111)
df.plot(kind='bar',stacked=True,ax=fig.gca())
#Below 3 lines remove default labels
labels = ['' for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)
ax.set_xlabel('')
label_group_bar_table(ax, df)
fig.subplots_adjust(bottom=.1*df.index.nlevels)
plt.show()
Run Code Online (Sandbox Code Playgroud)