如何转换 float64 使其在苹果硅中工作?

Ahs*_*que 7 python metal pytorch apple-silicon apple-m1

我正在尝试将预先训练的权重加载到mpsApple M1 的 GPU 设备。为了最大限度地重现该问题,我可以运行以下命令:

torch.load('yolov7_training.pt', map_location='mps')
Run Code Online (Sandbox Code Playgroud)

这会产生以下异常:

  File "train.py", line 619, in <module>
    train(hyp, opt, device, tb_writer)
  File "train.py", line 72, in train
    torch.load('yolov7_training.pt', map_location='mps')
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/serialization.py", line 789, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/serialization.py", line 1131, in _load
    result = unpickler.load()
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/_utils.py", line 153, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/_utils.py", line 146, in _rebuild_tensor
    t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
(torch-gpu)
Run Code Online (Sandbox Code Playgroud)

我是 pytorch 的初学者,并且没有看到在加载时将其转换为 float32 的选项,正如异常所建议的那样。如何使这项工作有效?

我愚蠢的解决方法是将其加载到 CPU 中,使其成为 float32,然后加载到 mps 设备。但不确定如何做到这一点,或者它是否会起作用。