如何将 PyTorch 模型转移到 Apple M1 芯片上的 GPU?

Ani*_*aha 13 metal pytorch apple-m1

2022 年 5 月 18 日,PyTorch宣布支持Mac 上的 GPU 加速 PyTorch 训练。

我按照以下过程在我的 Macbook Air M1 上设置 PyTorch(使用 miniconda)。

conda create -n torch-nightly python=3.8 

$ conda activate torch-nightly

$ pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
Run Code Online (Sandbox Code Playgroud)

我正在尝试执行此处提供的 Udacity 深度学习课程中的脚本。

该脚本使用以下代码将模型移动到 GPU:

conda create -n torch-nightly python=3.8 

$ conda activate torch-nightly

$ pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
Run Code Online (Sandbox Code Playgroud)

然而,这不适用于 M1 芯片,因为没有 CUDA。

如果我们想将模型迁移到 M1 GPU,将张量迁移到 M1 GPU,并完全在 M1 GPU 上进行训练,我们应该做什么?


如果相关:GD是 GAN 的判别器和生成器。

G.cuda()
D.cuda()
Run Code Online (Sandbox Code Playgroud)

Ani*_*aha 30

这是我用的:

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    G.to(mps_device)
    D.to(mps_device)
Run Code Online (Sandbox Code Playgroud)

同样,对于我想要转移到 M1 GPU 的所有张量,我使用了:

tensor_ = tensor_(mps_device)
Run Code Online (Sandbox Code Playgroud)

有些操作尚未使用 MPS 实现,我们可能需要设置一些环境变量来使用 CPU 回退:我在执行脚本期间遇到的一个错误是

# NotImplementedError: The operator 'aten::_slow_conv2d_forward' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
Run Code Online (Sandbox Code Playgroud)

为了解决这个问题我设置了环境变量PYTORCH_ENABLE_MPS_FALLBACK=1

conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1
conda activate <test-env>
Run Code Online (Sandbox Code Playgroud)

参考:

  1. https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
  2. https://pytorch.org/docs/master/notes/mps.html
  3. https://sebastianraschka.com/blog/2022/pytorch-m1-gpu.html
  4. https://sebastianraschka.com/blog/2022/pytorch-m1-gpu.html
  5. https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#setting-environment-variables