为什么 Java 矢量 API 与标量相比如此慢?

Red*_*mpt 3 java simd vectorization

我最近决定尝试一下 Java 的新孵化矢量 API,看看它能达到多快。我实现了两种相当简单的方法,一种用于解析 int,另一种用于查找字符串中字符的索引。在这两种情况下,与标量方法相比,我的矢量化方法都慢得令人难以置信。

这是我的代码:

public class SIMDParse {

private static IntVector mul = IntVector.fromArray(
        IntVector.SPECIES_512,
        new int[] {0, 0, 0, 0, 0, 0, 1000000000, 100000000, 10000000, 1000000, 100000, 10000, 1000, 100, 10, 1},
        0
);
private static byte zeroChar = (byte) '0';
private static int width = IntVector.SPECIES_512.length();
private static byte[] filler;

static {
    filler = new byte[16];
    for (int i = 0; i < 16; i++) {
        filler[i] = zeroChar;
    }
}

public static int parseInt(String str) {
    boolean negative = str.charAt(0) == '-';
    byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
    if (negative) {
        bytes[0] = zeroChar;
    }
    bytes = ensureSize(bytes, width);
    ByteVector vec = ByteVector.fromArray(ByteVector.SPECIES_128, bytes, 0);
    vec = vec.sub(zeroChar);
    IntVector ints = (IntVector) vec.castShape(IntVector.SPECIES_512, 0);
    ints = ints.mul(mul);
    return ints.reduceLanes(VectorOperators.ADD) * (negative ? -1 : 1);
}

public static byte[] ensureSize(byte[] arr, int per) {
    int mod = arr.length % per;
    if (mod == 0) {
        return arr;
    }
    int length = arr.length - (mod);
    length += per;
    byte[] newArr = new byte[length];
    System.arraycopy(arr, 0, newArr, per - mod, arr.length);
    System.arraycopy(filler, 0, newArr, 0, per - mod);
    return newArr;
}

public static byte[] ensureSize2(byte[] arr, int per) {
    int mod = arr.length % per;
    if (mod == 0) {
        return arr;
    }
    int length = arr.length - (mod);
    length += per;
    byte[] newArr = new byte[length];
    System.arraycopy(arr, 0, newArr, 0, arr.length);
    return newArr;
}

public static int indexOf(String s, char c) {
    byte[] b = s.getBytes(StandardCharsets.UTF_8);
    int width = ByteVector.SPECIES_MAX.length();
    byte bChar = (byte) c;
    b = ensureSize2(b, width);
    for (int i = 0; i < b.length; i += width) {
        ByteVector vec = ByteVector.fromArray(ByteVector.SPECIES_MAX, b, i);
        int pos = vec.compare(VectorOperators.EQ, bChar).firstTrue();
        if (pos != width) {
            return pos + i;
        }
    }
    return -1;
}

}
Run Code Online (Sandbox Code Playgroud)

我完全预计我的 int 解析会更慢,因为它永远不会处理超过向量大小可以容纳的内容(int 的长度永远不会超过 10 位数字)。

根据我的基准,解析123为 int 10k 次需要 3081 微秒Integer.parseInt,而我的实现需要 80601 微秒。'a'在很长的字符串 ( ) 中搜索"____".repeat(4000) + "a" + "----".repeat(193)需要 7709 微秒,即String#indexOf7。

为什么它慢得令人难以置信?我认为 SIMD 的全部要点在于,对于此类任务,它比标量等效项更快。

Pet*_*des 5

您选择了 SIMD 不擅长的东西(字符串->int),以及 JVM 非常擅长优化循环外的东西。如果输入不是向量宽度的精确倍数,则您会通过大量额外的复制工作来实现。


我假设您的时间是总计(每次重复 10k 次),而不是每次调用的平均值。

7 我们对于这个速度来说已经快得不可思议了。

"____".repeat(4000)之前有 16k 字符(32k 字节)'a',我认为这就是您要搜索的内容。即使是在 4GHz CPU 上以每个时钟周期 2x 32 字节向量运行的经过良好调整/展开wmemchr(又名 indexOf)的 10k 次重复也需要 1250 us。( 32000B / (64B/c) * 10000 reps / 4000 MHz),假设 32kB 字符串在 32KiB L1d 缓存中保持热状态。

我希望并期望 JVM 要么调用本机wmemchr,要么使用对常用核心库函数(如String#indexOf. 例如,glibc 的 avx2 memchr对循环展开进行了很好的调整。(Java 字符串实际上是 UTF-16,但 Linux 上的 Cwchar_t是 4 字节宽,与 Windows 不同,因此 JVM 需要自己的实现。)

内置字符串indexOf也是 JIT“了解”的内容。 当它看到您重复使用相同的字符串作为输入时,它显然能够将其从循环中提升出来。(但是接下来这 7 个我们在做什么?我猜做一个不太好的事情memchr,然后在 1/时钟执行一个空的 10k 迭代循环可能需要大约 7 微秒,特别是如果你的 CPU 不是快至 4GHz。)

看看绩效评估的惯用方式?- 如果将重复计数加倍到 20k 并没有使时间加倍,则您的基准测试已被破坏,并且没有衡量您认为的效果。


但你说 7 us 是每次迭代时间? 除了未优化的第一遍之外,这可能会慢得令人难以置信。因此,这可能是基准测试方法错误的迹象,例如缺乏热身运行。

如果IndexOf一次检查一个字符,16k * 0.25 ns/char则需要 4000 纳秒,或者在 4GHz CPU 上需要 4 微秒。7 us每个周期大约检查 1 个字符,这在现代 x86 上慢得可怜。我认为主流 JVM 在完成 JIT 优化后不太可能使用如此缓慢的实现。


您的手动 SIMD indexOf 不太可能在循环外得到优化。 如果大小不是向量宽度的精确倍数,它每次都会复制整个数组!(在ensureSize2)。正常的技术是回退到最后一个size % width元素的标量,这对于大型数组来说显然要好得多。或者更好的是,对于与先前工作重叠不成问题的内容,执行在数组末尾结束的未对齐加载(如果总大小 >= 向量宽度)。

现代 x86 上的一个不错的 memchr(使用像你的 indexOf 这样的算法而不展开)应该每 1.5 个时钟周期大约 1 个向量(16/32/64 字节),L1d 缓存中的数据很热,没有循环展开或任何东西。(检查向量比较和指针绑定作为可能的循环退出条件需要额外的 asm 指令,而不是简单的strlen,但请参阅此答案以获取假设对齐缓冲区的简单手写 strlen 的一些微基准)。对于像 Skylake 这样的 CPU,其管道宽度为 4 uop/时钟,您的循环可能会indexOf成为前端吞吐量的瓶颈。

因此,假设您使用的 CPU 没有 AVX2,那么您的实现每 16 字节向量需要 1.5 个周期?你没说。

16kB / 16B = 1000 个向量。如果每 1.5 个时钟 1 个向量,则为 1500 个周期。在 3GHz 机器上,1500 个周期每次调用需要 500 ns = 0.5 us,或者每 10k 次重复需要 5000 us。但由于 16194 字节不是 16 的倍数,因此您每次调用时都会复制整个内容,因此这会花费更多时间,并且可能会占您 7709 us 的总时间。


SIMD 有什么好处

对于这样的任务。

不,“水平”之类的东西ints.reduceLanesSIMD 通常速度很慢。 甚至像如何使用 SIMD 实现 atoi?使用 x86pmaddwd进行水平乘法和加法,仍然需要大量工作。

请注意,为了使元素足够宽以乘以位值而不会溢出,您必须解包,这会花费一些洗牌成本。 ints.reduceLanes大约需要 log2(elements) 洗牌/添加步骤,如果您从 的 512 位 AVX-512 向量开始int,则这些洗牌中的前 2 个是车道交叉,3 个周期延迟 ( https://agner.org/优化/)。(或者如果你的机器甚至没有 AVX2,那么 512 位整数向量实际上是 4x 128 位向量。并且你必须做单独的工作来解压每个部分。但至少减少会很便宜,只是垂直相加,直到得到一个 128 位向量。)