Convert a BERT model to TFLite

Ali*_*mon 5 python tensorflow tensorflow-lite bert-language-model

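I have code for a semantic search engine built on a pretrained BERT model. I want to convert this model to TFLite so that I can deploy it with Google ML Kit, and I would like to know how to do the conversion, or whether it is possible at all. It seems it should be, since the official TensorFlow site covers conversion: https://www.tensorflow.org/lite/convert. But I don't know where to start.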

Code:


import scipy.spatial.distance  # needed for the cdist call in the search loop below
from sentence_transformers import SentenceTransformer

# Load the BERT model. Various models trained on Natural Language Inference (NLI) https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/nli-models.md and 
# Semantic Textual Similarity are available https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/sts-models.md

model = SentenceTransformer('bert-base-nli-mean-tokens')

# A corpus is a list with documents split by sentences.

sentences = ['Absence of sanity', 
             'Lack of saneness',
             'A man is eating food.',
             'A man is eating a piece of bread.',
             'The girl is carrying a baby.',
             'A man is riding a horse.',
             'A woman is playing violin.',
             'Two men pushed carts through the woods.',
             'A man is riding a white horse on an enclosed ground.',
             'A monkey is playing drums.',
             'A cheetah is running behind its prey.']

# Each sentence is encoded as a 1-D vector with 768 columns (the BERT-base hidden size)
sentence_embeddings = model.encode(sentences)

print('Sample BERT embedding vector - length', len(sentence_embeddings[0]))

print('Sample BERT embedding vector - note includes negative values', sentence_embeddings[0])

#@title Semantic Search Form

# code adapted from https://github.com/UKPLab/sentence-transformers/blob/master/examples/application_semantic_search.py

query = 'Nobody has sane thoughts' #@param {type: 'string'}

queries = [query]
query_embeddings = model.encode(queries)

# Find the closest 3 sentences of the corpus for each query sentence based on cosine similarity
number_top_matches = 3 #@param {type: "number"}

print("Semantic Search Results")

for query, query_embedding in zip(queries, query_embeddings):
    distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]

    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    print("\n\n======================\n\n")
    print("Query:", query)
    print("\nTop %d most similar sentences in corpus:" % number_top_matches)

    for idx, distance in results[0:number_top_matches]:
        print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))
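Once the encoder runs on-device, SciPy will not be available there, but the ranking step above only needs cosine similarity. The following is a minimal NumPy sketch of the same top-k ranking (the function name `top_k_cosine` is hypothetical, and it assumes no embedding is an all-zero vector). It returns similarities, i.e. `1 - distance` as printed above.

```python
import numpy as np


def top_k_cosine(query_embedding, corpus_embeddings, k=3):
    """Return (index, cosine_similarity) pairs for the k most similar corpus rows.

    Hypothetical helper; assumes no vector is all zeros.
    """
    q = np.asarray(query_embedding, dtype=np.float32)
    c = np.asarray(corpus_embeddings, dtype=np.float32)
    # L2-normalize so that a dot product equals cosine similarity.
    q = q / np.linalg.norm(q)
    c = c / np.linalg.norm(c, axis=1, keepdims=True)
    scores = c @ q
    # Sort by descending similarity and keep the first k.
    order = np.argsort(-scores)[:k]
    return [(int(i), float(scores[i])) for i in order]


print(top_k_cosine([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], k=2))
```

Normalizing once and using a single matrix-vector product keeps the ranking cheap enough to run in app code after the model returns embeddings.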

Jin*_*ich 0

First, you need the model in TensorFlow; the package you are using is written in PyTorch. Hugging Face's Transformers library provides TensorFlow models that you can start from. They also have TFLite-ready models for Android.
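One detail to keep in mind: as its name suggests, `bert-base-nli-mean-tokens` turns BERT's per-token outputs into a sentence vector by averaging them (mean pooling). A plain TensorFlow BERT returns per-token outputs, so you would need to reproduce that pooling yourself, either inside the model before conversion or in app code. Here is a minimal NumPy sketch of a masked average. The name `mean_pool` is hypothetical, and it assumes you already have the last hidden state and the attention mask as arrays.

```python
import numpy as np


def mean_pool(token_embeddings, attention_mask):
    """Average token vectors, ignoring padding positions (hypothetical helper).

    token_embeddings: (batch, seq_len, hidden) array, e.g. a BERT last hidden state.
    attention_mask:   (batch, seq_len) array, 1 for real tokens and 0 for padding.
    """
    emb = np.asarray(token_embeddings, dtype=np.float32)
    mask = np.asarray(attention_mask, dtype=np.float32)[..., np.newaxis]
    summed = (emb * mask).sum(axis=1)                    # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)       # (batch, 1), avoids division by zero
    return summed / counts


# The third token is padding, so only the first two are averaged.
print(mean_pool([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]], [[1, 1, 0]]))
```

Masking matters because padded positions still produce hidden states; averaging them in would shift every embedding depending on how much padding the batch needed.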

In general, you start from a TensorFlow model. Then save it in the SavedModel format:

tf.saved_model.save(pretrained_model, "/tmp/pretrained-bert/1/")

You can then run the converter on this SavedModel.
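A minimal end-to-end sketch of the save-then-convert flow, assuming TensorFlow 2.x. To stay self-contained it uses a tiny stand-in `tf.Module` (hypothetical, not a BERT) and a temporary directory in place of `/tmp/pretrained-bert/1/`. It converts with `tf.lite.TFLiteConverter.from_saved_model` and runs the result once with `tf.lite.Interpreter` as a sanity check.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyEncoder(tf.Module):
    """Hypothetical stand-in for a TensorFlow sentence encoder."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([4, 2]))

    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)


model = TinyEncoder()
export_dir = tempfile.mkdtemp()  # stands in for "/tmp/pretrained-bert/1/"
tf.saved_model.save(model, export_dir,
                    signatures=model.__call__.get_concrete_function())

# Convert the SavedModel into a TFLite flatbuffer (bytes).
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()
with open(os.path.join(tempfile.mkdtemp(), "model.tflite"), "wb") as f:
    f.write(tflite_model)

# Sanity check: run the converted model once with the Python interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], np.ones([1, 4], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(out["index"])
print(result)
```

For a real BERT graph, some ops may not be TFLite builtins. In that case setting `converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]` before `convert()` is one option, at the cost of needing the Select TF ops (Flex) runtime in the app.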