16位浮点MPI_Reduce?

sol*_*les 4 c++ precision parallel-processing mpi

我有一个MPI_Reduce()用于某些通信的分布式应用程序.在精度方面,我们使用16位浮点数(半精度)获得完全准确的结果.

为了加速通信(减少数据移动量),有没有办法调用MPI_Reduce()16位浮点数?


(我查看了MPI文档,但没有看到有关16位浮点数的任何信息.)

Pat*_*ick 5

MPI标准在其内部数据类型中仅定义了32位(MPI_FLOAT)或64位(MPI_DOUBLE)浮点数.

但是,您始终可以创建MPI_Datatype自己的自定义缩减操作.下面的代码粗略地说明了如何做到这一点.由于不清楚你正在使用哪个16位浮点实现,我将简单地称为类型float16_t和加法操作fp16_add().

// define custom reduce operation
void my_fp16_sum(void* invec, void* inoutvec, int *len,
              MPI_Datatype *datatype) {
    // cast invec and inoutvec to your float16 type
    float16_t* in = (float16_t)invec;
    float16_t* inout = (float16_t)inoutvec;
    for (int i = 0; i < *len; ++i) {
        // sum your 16 bit floats
        *inout = fp16_add(*in, *inout);
    }
}

// ...

//  in your code:

// create 2-byte datatype (send raw, un-interpreted bytes)
MPI_Datatype mpi_type_float16;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_float16);
MPI_Type_commit(&mpi_type_float16);

// create user op (pass function pointer to your user function)
MPI_Op mpi_fp16sum;
MPI_Op_create(&my_fp16_sum, 1, &mpi_fp16sum);

// call MPI_Reduce using your custom reduction operation
MPI_Reduce(&fp16_val, &fp16_result, 1, mpi_type_float16, mpi_fp16sum, 0, MPI_COMM_WORLD);

// clean up (freeing of the custom MPI_Op and MPI_Datatype)
MPI_Type_free(&mpi_type_float16);
MPI_Op_free(&mpi_fp16sum);
Run Code Online (Sandbox Code Playgroud)

  • 你错过了对`MPI_Type_commit()`的调用.还应该可以滥用较新的MPI库中的语言互操作性功能,并通过诸如`mpi_type_float16 = MPI_Type_f2c(MPI_REAL2);`之类的东西来利用Fortran`REAL*2`类型.它仍然需要一个用户定义的简化运算符,因为标准的运算符不能在`MPI_REAL2`上运行. (3认同)