Vah*_*yan 5 gpu compute-shader ios metal
任何人都知道在金属内核中使用随机浮点数计算缓冲区平均值的正确方法吗?
在计算命令编码器上分派工作:
threadsPerGroup = MTLSizeMake(1, 1, inputTexture.arrayLength);
numThreadGroups = MTLSizeMake(1, 1, inputTexture.arrayLength / threadsPerGroup.depth);
[commandEncoder dispatchThreadgroups:numThreadGroups
threadsPerThreadgroup:threadsPerGroup];
Run Code Online (Sandbox Code Playgroud)
内核代码:
kernel void mean(texture2d_array<float, access::read> inTex [[ texture(0) ]],
device float *means [[ buffer(1) ]],
uint3 id [[ thread_position_in_grid ]]) {
if (id.x == 0 && id.y == 0) {
float mean = 0.0;
for (uint i = 0; i < inTex.get_width(); ++i) {
for (uint j = 0; j < inTex.get_height(); ++j) {
mean += inTex.read(uint2(i, j), id.z)[0];
}
}
float textureArea = inTex.get_width() * inTex.get_height();
mean /= textureArea;
out[id.z] = mean;
}
}
Run Code Online (Sandbox Code Playgroud)
缓冲区以texture2d_array 类型的纹理表示,采用R32Float 像素格式。
如果您可以使用 uint 数组(而不是 float)作为数据源,我建议使用“原子获取和修改函数”(如金属着色语言规范中所述)以原子方式写入缓冲区。
下面是一个内核函数的示例,它接受输入缓冲区(数据:Float 数组)并将缓冲区的总和写入原子缓冲区(总和,指向 uint 的指针):
kernel void sum(device uint *data [[ buffer(0) ]],
volatile device atomic_uint *sum [[ buffer(1) ]],
uint gid [[ thread_position_in_grid ]])
{
atomic_fetch_add_explicit(sum, data[gid], memory_order_relaxed);
}
Run Code Online (Sandbox Code Playgroud)
在您的 swift 文件中,您将设置缓冲区:
...
let data: [UInt] = [1, 2, 3, 4]
let dataBuffer = device.makeBuffer(bytes: &data, length: (data.count * MemoryLayout<UInt>.size), options: [])
commandEncoder.setBuffer(dataBuffer, offset: 0, at: 0)
var sum:UInt = 0
let sumBuffer = device!.makeBuffer(bytes: &sum, length: MemoryLayout<UInt>.size, options: [])
commandEncoder.setBuffer(sumBuffer, offset: 0, at: 1)
commandEncoder.endEncoding()
Run Code Online (Sandbox Code Playgroud)
提交,等待,然后从 GPU 获取数据:
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
let nsData = NSData(bytesNoCopy: sumBuffer.contents(),
length: sumBuffer.length,
freeWhenDone: false)
nsData.getBytes(&sum, length:sumBuffer.length)
let mean = Float(sum/data.count)
print(mean)
Run Code Online (Sandbox Code Playgroud)
或者,如果您的初始数据源必须是浮点数组,您可以使用Accelerate 框架的vDSP_meanv方法,该方法对于此类计算非常快。
我希望有帮助,干杯!
| 归档时间: |
|
| 查看次数: |
1866 次 |
| 最近记录: |