How do I set the jitter seed in a seaborn stripplot?

Hau*_*son 6 python matplotlib random-seed seaborn jitter

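I am trying to reproduce strip plots exactly so that I can reliably draw lines and write on them. However, when I generate a strip plot with jitter, the jitter is random, which prevents me from achieving that goal.

I blindly tried some rcParams settings that I found in other Stack Overflow posts, such as mpl.rcParams['svg.hashsalt'], which did not work. I also tried setting a seed with random.seed(), without success.

The code I am running is shown below.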

import seaborn as sns
import matplotlib.pyplot as plt
import random

plt.figure(figsize=(14,9))

random.seed(123)

catagories = []
values = []

for i in range(0,200):
    n = random.randint(1,3)
    catagories.append(n)

for i in range(0,200):
    n = random.randint(1,100)
    values.append(n)

# Keyword arguments work across seaborn versions; newer releases reject positional x and y
sns.stripplot(x=catagories, y=values, size=5)
plt.title('Random Jitter')
plt.xticks([0,1,2],[1,2,3])
plt.show()

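This code produces a stripplot just like the one I want. However, if you run the code twice, you get different point positions because of the jitter. The chart I am making needs jitter so that it does not look ridiculous, but I also want to write on the chart. There is no way to know the exact positions of the points before running the code, and they change every time the code is run.

Is there a way to set the jitter seed in seaborn stripplots so that they are perfectly reproducible?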

Tre*_*ney 7

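  • The jitter is determined by scipy.stats.uniform
  • uniform is an instance of class uniform_gen(scipy.stats._distn_infrastructure.rv_continuous)
  • which is a subclass of class rv_continuous(rv_generic)
  • which has a seed parameter and uses np.random
  • Therefore, use np.random.seed()
    • It needs to be called before each plot. In the case of the example, np.random.seed(123) must be inside the loop.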
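The chain above can be checked without plotting anything. This is a minimal sketch, assuming the frozen scipy.stats.uniform distribution is created without its own seed and therefore falls back to NumPy's global random state; the 0.1 half-width is the default that seaborn uses for jitter=True in the source quoted in this answer.

```python
import random

import numpy as np
from scipy import stats

# Same kind of frozen distribution that _StripPlotter builds for jitter=True (jlim = 0.1)
jitterer = stats.uniform(-0.1, 0.2).rvs

np.random.seed(123)
first = jitterer(5)

np.random.seed(123)   # re-seeding NumPy's global state replays the same draws
second = jitterer(5)

random.seed(123)      # the stdlib generator is a separate stream and has no effect here
third = jitterer(5)

print(np.array_equal(first, second))
print(np.array_equal(first, third))
```

The first comparison shows why np.random.seed() has to be called again before every plot: each call to the jitterer advances the global state, so only a fresh seed replays the same offsets.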

From the stripplot docstring

jitter : float, ``True``/``1`` is special-cased, optional
    Amount of jitter (only along the categorical axis) to apply. This
    can be useful when you have many points and they overlap, so that
    it is easier to see the distribution. You can specify the amount
    of jitter (half the width of the uniform random variable support),
    or just use ``True`` for a good default.

From class _StripPlotter in categorical.py

  • The jitter is computed with scipy.stats.uniform
from scipy import stats

class _StripPlotter(_CategoricalScatterPlotter):
    """1-d scatterplot with categorical organization."""
    def __init__(self, x, y, hue, data, order, hue_order,
                 jitter, dodge, orient, color, palette):
        """Initialize the plotter."""
        self.establish_variables(x, y, hue, data, orient, order, hue_order)
        self.establish_colors(color, palette, 1)

        # Set object attributes
        self.dodge = dodge
        self.width = .8

        if jitter == 1:  # Use a good default for `jitter = True`
            jlim = 0.1
        else:
            jlim = float(jitter)
        if self.hue_names is not None and dodge:
            jlim /= len(self.hue_names)
        self.jitterer = stats.uniform(-jlim, jlim * 2).rvs
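To make the jlim logic concrete, here is a small sketch that mirrors it outside of seaborn. The helper name jitter_limit and its n_hue_levels argument are hypothetical and exist only for illustration; the arithmetic follows the quoted source, where stats.uniform(-jlim, jlim * 2) means loc = -jlim and scale = 2 * jlim, so offsets fall in [-jlim, jlim].

```python
from scipy import stats


def jitter_limit(jitter, n_hue_levels=None, dodge=False):
    """Hypothetical helper mirroring the jlim computation in _StripPlotter."""
    jlim = 0.1 if jitter == 1 else float(jitter)  # True == 1, so jitter=True maps to 0.1
    if n_hue_levels is not None and dodge:
        jlim /= n_hue_levels                      # narrower band per dodged hue level
    return jlim


jlim = jitter_limit(True)
dist = stats.uniform(-jlim, jlim * 2)             # uniform on [-jlim, jlim]
samples = dist.rvs(size=1000, random_state=0)     # fixed seed only for this demo
print(samples.min() >= -jlim, samples.max() <= jlim)
```

This is why the docstring describes the jitter value as half the width of the uniform random variable's support.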

From the rv_continuous docstring

    seed : {None, int, `~np.random.RandomState`, `~np.random.Generator`}, optional
        This parameter defines the object to use for drawing random variates.
        If `seed` is `None` the `~np.random.RandomState` singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used, seeded
        with seed.
        If `seed` is already a ``RandomState`` or ``Generator`` instance,
        then that object is used.
        Default is None.
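The same seeding options are available per call through the random_state argument of rvs. The sketch below illustrates only the SciPy side; the seaborn source quoted above calls rvs without it, which is why the global np.random.seed() is needed for the plot itself.

```python
import numpy as np
from scipy import stats

dist = stats.uniform(-0.1, 0.2)  # uniform on [-0.1, 0.1]

# An int seeds a fresh generator for that call only, so repeated calls match
a = dist.rvs(size=5, random_state=123)
b = dist.rvs(size=5, random_state=123)

# An existing RandomState instance is used as-is and advances between calls
rng = np.random.RandomState(123)
c = dist.rvs(size=5, random_state=rng)
d = dist.rvs(size=5, random_state=rng)

print(np.array_equal(a, b))
print(np.array_equal(c, d))
```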

Using your code with np.random.seed

  • All plot points are the same
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    np.random.seed(123)

    catagories = []
    values = []

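    # np.random.randint excludes the upper bound, unlike random.randint,
    # so the bounds are raised by one to keep the ranges 1..3 and 1..100
    for i in range(0,200):
        n = np.random.randint(1,4)
        catagories.append(n)

    for i in range(0,200):
        n = np.random.randint(1,101)
        values.append(n)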

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    sns.stripplot(x=catagories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()

Using only random

  • The plot points move around
import seaborn as sns
import matplotlib.pyplot as plt
import random

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    random.seed(123)

    catagories = []
    values = []

    for i in range(0,200):
        n = random.randint(1,3)
        catagories.append(n)

    for i in range(0,200):
        n = random.randint(1,100)
        values.append(n)

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    sns.stripplot(x=catagories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'random jitter {x+1}')
plt.show()

Using random for the data and np.random.seed for the plot

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import random

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):

    random.seed(123)

    catagories = []
    values = []

    for i in range(0,200):
        n = random.randint(1,3)
        catagories.append(n)

    for i in range(0,200):
        n = random.randint(1,100)
        values.append(n)

    row = x // 3
    col = x % 3
    axcurr = axes[row, col]

    np.random.seed(123)  # seed NumPy right before plotting so the jitter repeats
    sns.stripplot(x=catagories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()

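Once the jitter is seeded, the point positions can be read back from the axes, which is what annotating the plot requires. This is a rough sketch, assuming the installed seaborn version draws its jitter from NumPy's global random state and adds the points as scatter collections; the helper name jittered_positions is made up for this example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def jittered_positions(seed):
    """Draw a seeded stripplot and return the (x, y) positions of its points."""
    rng = np.random.RandomState(0)            # fixed data, independent of the jitter seed
    categories = rng.randint(1, 4, size=200)  # integers 1..3 (upper bound is exclusive)
    values = rng.randint(1, 101, size=200)    # integers 1..100

    fig, ax = plt.subplots()
    np.random.seed(seed)                      # seed the global state right before plotting
    sns.stripplot(x=categories, y=values, size=5, ax=ax)

    # Each scatter collection stores the jittered coordinates of its points
    offsets = np.vstack([np.asarray(c.get_offsets()) for c in ax.collections])
    plt.close(fig)
    return offsets


first = jittered_positions(123)
second = jittered_positions(123)
print(np.array_equal(first, second))
```

With identical positions on every run, coordinates taken from the returned array can be passed to ax.annotate or ax.plot to write and draw on the chart.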