小编fra*_*bit的帖子

是否可以在不首先在本地持久化的情况下从 GCS 存储桶 URL 加载预训练的 Pytorch 模型?

我是在 Google Dataflow 的背景下问这个问题的,但也是一般的。

使用 PyTorch,我可以引用包含多个文件的本地目录,这些文件构成一个预训练模型。我碰巧使用的是 Roberta 模型,但其他人的界面是一样的。

ls some-directory/
      added_tokens.json
      config.json             
      merges.txt              
      pytorch_model.bin       
      special_tokens_map.json vocab.json
Run Code Online (Sandbox Code Playgroud)
ls some-directory/
      added_tokens.json
      config.json             
      merges.txt              
      pytorch_model.bin       
      special_tokens_map.json vocab.json
Run Code Online (Sandbox Code Playgroud)

但是,我的预训练模型存储在 GCS 存储桶中。让我们称之为gs://my-bucket/roberta/

在 Google Dataflow 中加载这个模型的上下文中,我试图保持无状态并避免持久化到磁盘,所以我更喜欢直接从 GCS 获取这个模型。据我了解,PyTorch 通用接口方法from_pretrained()可以采用本地目录或 URL 的字符串表示形式。但是,我似乎无法从 GCS URL 加载模型。

from pytorch_transformers import RobertaModel

# this works
model = RobertaModel.from_pretrained('/path/to/some-directory/')
Run Code Online (Sandbox Code Playgroud)

如果我尝试使用目录 blob 的公共 https URL,它也会失败,尽管这可能是由于缺乏身份验证,因为在可以创建客户端的 python 环境中引用的凭据不会转换为公共请求https://storage.googleapis

# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# …
Run Code Online (Sandbox Code Playgroud)

python google-cloud-storage google-cloud-dataflow pytorch

7
推荐指数
1
解决办法
2850
查看次数