我是在 Google Dataflow 的背景下问这个问题的,但也是一般的。
使用 PyTorch,我可以引用包含多个文件的本地目录,这些文件构成一个预训练模型。我碰巧使用的是 Roberta 模型,但其他人的界面是一样的。
ls some-directory/
added_tokens.json
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json vocab.json
Run Code Online (Sandbox Code Playgroud)
ls some-directory/
added_tokens.json
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json vocab.json
Run Code Online (Sandbox Code Playgroud)
但是,我的预训练模型存储在 GCS 存储桶中。让我们称之为gs://my-bucket/roberta/
。
在 Google Dataflow 中加载这个模型的上下文中,我试图保持无状态并避免持久化到磁盘,所以我更喜欢直接从 GCS 获取这个模型。据我了解,PyTorch 通用接口方法from_pretrained()
可以采用本地目录或 URL 的字符串表示形式。但是,我似乎无法从 GCS URL 加载模型。
from pytorch_transformers import RobertaModel
# this works
model = RobertaModel.from_pretrained('/path/to/some-directory/')
Run Code Online (Sandbox Code Playgroud)
如果我尝试使用目录 blob 的公共 https URL,它也会失败,尽管这可能是由于缺乏身份验证,因为在可以创建客户端的 python 环境中引用的凭据不会转换为公共请求https://storage.googleapis
# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# …
Run Code Online (Sandbox Code Playgroud)