Using pretrained nodes from past runs - PyTorch-BigGraph

Yeh*_*ens 7 python graph pytorch

After struggling with the amazing facebookresearch/PyTorch-BigGraph project and its impossible API, I managed to figure out how to run it (thanks to a self-contained simple example).

My system constraints do not allow me to train dense (embedding) representations for all of the edges at once. I need to load past embeddings from time to time and train the model on new edges together with the existing nodes; note that the node lists of the past and the new edges do not necessarily overlap.

I tried to figure out how to do this from here (see the context section), so far without success.

Below is a self-contained PBG script that turns batch_edges into a list of embedded nodes; however, I need it to use the pretrained node list past_trained_nodes.

import os
import shutil
from pathlib import Path

from torchbiggraph.config import parse_config
from torchbiggraph.converters.importers import TSVEdgelistReader, convert_input_data
from torchbiggraph.train import train
from torchbiggraph.util import SubprocessInitializer, setup_logging

DIMENSION = 4
DATA_DIR = 'data'
GRAPH_PATH = DATA_DIR + '/output1.tsv'
MODEL_DIR = 'model'


raw_config = dict(
        entity_path=DATA_DIR,
        edge_paths=[DATA_DIR + '/edges_partitioned', ],
        checkpoint_path=MODEL_DIR,
        entities={"n": {"num_partitions": 1}},
        relations=[{"name": "doesnt_matter", "lhs": "n", "rhs": "n", "operator": "complex_diagonal", }],
        dynamic_relations=False, dimension=DIMENSION, global_emb=False, comparator="dot",
        num_epochs=7, num_uniform_negs=1000, loss_fn="softmax", lr=0.1, eval_fraction=0.,)

batch_edges = [["A", "B"], ["B", "C"],  ["C", "D"], ["D", "B"], ["B", "D"]]


# I want the model to use these pretrained nodes. Note that node A exists in batch_edges and F does not.
# I don't have all the past nodes, as some are obtained from the data.
past_trained_nodes = {'A': [0.5, 0.3, 1.5, 8.1], 'F': [3, 0.6, 1.2, 4.3]}


# start from a clean slate: remove leftovers from previous runs
shutil.rmtree(DATA_DIR, ignore_errors=True)
shutil.rmtree(MODEL_DIR, ignore_errors=True)

os.makedirs(DATA_DIR, exist_ok=True)
with open(GRAPH_PATH, 'w') as f:
    for edge in batch_edges:
        f.write('\t'.join(edge) + '\n')

setup_logging()
config = parse_config(raw_config)
subprocess_init = SubprocessInitializer()
input_edge_paths = [Path(GRAPH_PATH)]

convert_input_data(config.entities, config.relations, config.entity_path, config.edge_paths,
                   input_edge_paths, TSVEdgelistReader(lhs_col=0, rel_col=None, rhs_col=1),
                   dynamic_relations=config.dynamic_relations, )

train(config, subprocess_init=subprocess_init)

How can I use my pretrained nodes in the current model?

Thanks in advance!

meT*_*sky 4

Since torchbiggraph is file-based, you can modify the saved files to load pretrained embeddings and to add new nodes. I wrote a function to do this:

import json
import h5py

def pretrained_and_new_nodes(pretrained_nodes,new_nodes,entity_name,data_dir,embeddings_path):
    """
    pretrained_nodes: 
        A dictionary of nodes and their embeddings 
    new_nodes:
        A list of new nodes; each new node must have an embedding in pretrained_nodes.
        If no new nodes, use []
    entity_name: 
        The entity's name, for example, WHATEVER_0
    data_dir: 
        The path to the files that record graph nodes and edges 
    embeddings_path: 
        The path to the .h5 file of embeddings
    """
    with open('%s/entity_names_%s.json' % (data_dir,entity_name),'r') as source:
        nodes = json.load(source)
    dist = {item:ind for ind,item in enumerate(nodes)}
    
    if len(new_nodes) > 0:
        # modify both the node names and the node count 
        extended = nodes.copy()
        extended.extend(new_nodes)
        with open('%s/entity_names_%s.json' % (data_dir,entity_name),'w') as source:
            json.dump(extended,source)
        with open('%s/entity_count_%s.txt' % (data_dir,entity_name),'w') as source:
            source.write('%i' % len(extended))
    
    if len(new_nodes) == 0:
        # if no new nodes are added, don't bother creating a new .h5 file; just modify the original one
        with h5py.File(embeddings_path,'r+') as source:

            for node,embedding in pretrained_nodes.items():
                if node in nodes:
                    source['embeddings'][dist[node]] = embedding
    else:
        # if there are new nodes, then we must create a new .h5 file 
        # see /sf/answers/3295218181/ 
        with h5py.File(embeddings_path,'r+') as source:
            embeddings = list(source['embeddings'])
            optimizer = list(source['optimizer'])
        for node,embedding in pretrained_nodes.items():
            if node in nodes:
                embeddings[dist[node]] = embedding
        # append new nodes in order 
        for node in new_nodes:
            if node not in pretrained_nodes:
                raise ValueError('new node %s has no embedding in pretrained_nodes' % node)
            embeddings.append(pretrained_nodes[node])
        # write a new .h5 file for the embedding 
        with h5py.File(embeddings_path,'w') as source:
            source.create_dataset('embeddings',data=embeddings,)
            optimizer = [item.encode('ascii') for item in optimizer]
            source.create_dataset('optimizer',data=optimizer)
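
One practical detail: the embeddings_path argument must point at the checkpoint version you want to edit. The checkpoint files are named embeddings_<entity>_<partition>.vN.h5, one file per checkpoint version. Here is a small sketch that picks the highest .vN file in a checkpoint directory (pick_latest_embeddings is my own helper name, not part of PBG):

import glob
import re

def pick_latest_embeddings(model_dir, entity_name):
    """Return the path of the highest-versioned embeddings_<entity_name>.vN.h5 file."""
    candidates = glob.glob('%s/embeddings_%s.v*.h5' % (model_dir, entity_name))
    if not candidates:
        raise FileNotFoundError('no embeddings checkpoints found under %s' % model_dir)
    # sort by the integer N in the ".vN.h5" suffix and return the newest
    return max(candidates, key=lambda p: int(re.search(r'\.v(\d+)\.h5$', p).group(1)))

# e.g. pick_latest_embeddings('model_1', 'WHATEVER_0') -> 'model_1/embeddings_WHATEVER_0.v7.h5'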

After training the model (say, with the simple example you linked in your post), suppose you want to change the learned embedding of node A to [0.5, 0.3, 1.5, 8.1]. In addition, you want to add a new node F to the graph with the embedding [3, 0.6, 1.2, 4.3] (this newly added node F has no connections to the other nodes). You can run my function like this:

past_trained_nodes = {'A': [0.5, 0.3, 1.5, 8.1], 'F': [3, 0.6, 1.2, 4.3]}
pretrained_and_new_nodes(pretrained_nodes=past_trained_nodes,
                         new_nodes=['F'],
                         entity_name='WHATEVER_0',
                         data_dir='data/example_1',
                         embeddings_path='model_1/embeddings_WHATEVER_0.v7.h5')
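
Note that entity_name and the file names depend on your config: the linked example uses an entity type called WHATEVER, hence entity_names_WHATEVER_0.json and embeddings_WHATEVER_0.v7.h5. The script in your question defines entities={"n": ...} with a single partition, so its files should instead carry an n_0 suffix under DATA_DIR and MODEL_DIR. A sketch of the equivalent call for that script (the .v7 suffix assumes the latest checkpoint after the 7 configured epochs; use the helper above, or simply pick the highest .vN file present):

past_trained_nodes = {'A': [0.5, 0.3, 1.5, 8.1], 'F': [3, 0.6, 1.2, 4.3]}
pretrained_and_new_nodes(pretrained_nodes=past_trained_nodes,
                         new_nodes=['F'],
                         entity_name='n_0',
                         data_dir='data',
                         embeddings_path='model/embeddings_n_0.v7.h5')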

After running this function, you can inspect the modified embeddings file embeddings_WHATEVER_0.v7.h5:

filename = "model_1/embeddings_WHATEVER_0.v7.h5" 

with h5py.File(filename, "r") as source:
    embeddings = list(source['embeddings'])

embeddings

You will see that the embedding of A has changed and that an embedding for F has been appended (the order of embeddings matches the order of the nodes in entity_names_WHATEVER_0.json).
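
If you want to verify that alignment yourself, a short check like the following (using the same example paths as above) pairs each node name with its row in the .h5 file:

import json
import h5py

with open('data/example_1/entity_names_WHATEVER_0.json', 'r') as source:
    names = json.load(source)

with h5py.File('model_1/embeddings_WHATEVER_0.v7.h5', 'r') as source:
    embeddings = source['embeddings'][...]

# row i of the embeddings dataset belongs to the i-th name in the JSON list
for name, vector in zip(names, embeddings):
    print(name, vector.tolist())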

After modifying the files, you can use the pretrained nodes in a new training session.
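
Here is a minimal sketch of such a session, reusing raw_config from your script and assuming PBG's usual behaviour of resuming from the newest checkpoint it finds in checkpoint_path. Keep MODEL_DIR intact (no rmtree), make sure the entity files under entity_path still match the row order of the edited .h5 file (re-apply the function above after any fresh convert_input_data), and give the run extra epochs so there is work left to do:

from torchbiggraph.config import parse_config
from torchbiggraph.train import train
from torchbiggraph.util import SubprocessInitializer, setup_logging

# Reuse the question's raw_config, but do NOT delete MODEL_DIR first: the edited
# embeddings_*.vN.h5 in checkpoint_path is the point training resumes from.
continued_config = parse_config(dict(raw_config, num_epochs=raw_config['num_epochs'] + 7))

setup_logging()
train(continued_config, subprocess_init=SubprocessInitializer())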