如何从本地系统加载 TF hub 模型

Tal*_*war 9 python tensorflow

一种方法是每次从tensorflow_hub如下所示下载模型

import tensorflow as tf
import tensorflow_hub as hub

hub_url = "https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1"
embed = hub.KerasLayer(hub_url)
embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
print(embeddings.shape, embeddings.dtype)
Run Code Online (Sandbox Code Playgroud)

我想下载一次文件并一次又一次地使用而不是每次都下载

Mat*_*mas 10

  1. 从 url +“?tf-hub-format=compressed”下载模型,
    例如“https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1?tf-hub-format=compressed”
  2. 解压
  3. 在代码中加载解压的文件夹
import tensorflow as tf
import tensorflow_hub as hub

embed = hub.KerasLayer('path/to/untarred/folder')
embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
print(embeddings.shape, embeddings.dtype)
Run Code Online (Sandbox Code Playgroud)


Shu*_*hal 5

您可以使用该hub.load()方法加载 TF Hub 模块。另外,文档说,

目前,仅 TensorFlow 2.x 以及通过调用创建的模块完全支持此方法tensorflow.saved_model.save()。该方法适用于急切模式和图形模式。

hub.load方法有一个参数handle。模块句柄的类型是,

  1. 智能 URL 解析器,例如 tfhub.dev,例如:https://tfhub.dev/google/nnlm-en-dim128/1

  2. Tensorflow 支持的文件系统上包含模块文件的目录。这可能包括本地目录(例如/usr/local/mymodule)或 Google Cloud Storage 存储桶(gs://mymodule)。

  3. 指向模块 TGZ 存档的 URL,例如https://example.com/mymodule.tar.gz.

您可以使用第二点和第三点。

  • `hub_module = hub.load('/tmp/atory-image-stylization-v1-256/')` 不起作用。难道我做错了什么? (3认同)