Source code for pykg2vec.utils.trainer

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import warnings
import torch
import torch.optim as optim
import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from pykg2vec.utils.evaluator import Evaluator
from pykg2vec.utils.visualization import Visualization
from pykg2vec.utils.riemannian_optimizer import RiemannianOptimizer
from pykg2vec.data.generator import Generator
from pykg2vec.utils.logger import Logger
from pykg2vec.common import Importer, Monitor, TrainingStrategy
warnings.filterwarnings('ignore')


class EarlyStopper:
    """ Class used by the trainer to handle the early stopping mechanism during the training of KGE algorithms.

        Args:
            patience (int): Number of epochs to wait before early stopping the training on no improvement.
                No early stopping if it is a negative number (default: {-1}).
            monitor (Monitor): The type of metric that the EarlyStopper will monitor.

    """

    _logger = Logger().get_logger(__name__)

    def __init__(self, patience, monitor):
        self.monitor = monitor
        self.patience = patience

        # controlling variables.
        self.previous_metrics = None
        self.patience_left = patience

    def should_stop(self, curr_metrics):
        should_stop = False
        value, name = self.monitor.value, self.monitor.name

        if self.previous_metrics is not None:
            # For (filtered) mean rank, lower is better; for other metrics, higher is better.
            if self.monitor == Monitor.MEAN_RANK or self.monitor == Monitor.FILTERED_MEAN_RANK:
                is_worse = self.previous_metrics[value] < curr_metrics[value]
            else:
                is_worse = self.previous_metrics[value] > curr_metrics[value]

            if self.patience_left > 0 and is_worse:
                self.patience_left -= 1
                self._logger.info(
                    '%s more chances before the trainer stops the training. (prev_%s, curr_%s): (%.4f, %.4f)' %
                    (self.patience_left, name, name, self.previous_metrics[value], curr_metrics[value]))

            elif self.patience_left == 0 and is_worse:
                self._logger.info('Stop the training.')
                should_stop = True

            else:
                self._logger.info('Reset the patience count to %d' % self.patience)
                self.patience_left = self.patience

        self.previous_metrics = curr_metrics

        return should_stop
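As a standalone illustration (not part of this module), here is a minimal sketch of how the stopper is driven by periodic evaluation results. The metric values below are hypothetical; in practice the dictionaries come from Evaluator.mini_test:

from pykg2vec.common import Monitor

stopper = EarlyStopper(patience=2, monitor=Monitor.FILTERED_MEAN_RANK)
key = stopper.monitor.value  # the metric key the stopper tracks

# Lower is better for (filtered) mean rank, so rising values consume patience.
for fmr in [120.0, 150.0, 170.0, 180.0]:  # hypothetical filtered mean ranks
    if stopper.should_stop({key: fmr}):
        print("early stop triggered")
        break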
class Trainer:
    """ Class for handling the training of the algorithms.

        Args:
            model (object): KGE model object
            config (object): Configuration object holding the hyperparameters.

        Examples:
            >>> from pykg2vec.utils.trainer import Trainer
            >>> from pykg2vec.models.pairwise import TransE
            >>> model = TransE(**config.__dict__)  # config built beforehand
            >>> trainer = Trainer(model, config)
            >>> trainer.build_model()
            >>> trainer.train_model()
    """

    TRAINED_MODEL_FILE_NAME = "model.vec.pt"
    TRAINED_MODEL_CONFIG_NAME = "config.npy"

    _logger = Logger().get_logger(__name__)

    def __init__(self, model, config):
        self.model = model
        self.config = config

        self.best_metric = None
        self.monitor = None
        self.training_results = []

        self.evaluator = None
        self.generator = None
        self.optimizer = None
        self.early_stopper = None
    def build_model(self, monitor=Monitor.FILTERED_MEAN_RANK):
        """Function to build the model."""
        if self.config.load_from_data is not None:
            self.load_model(self.config.load_from_data)

        self.evaluator = Evaluator(self.model, self.config)
        self.model.to(self.config.device)

        if self.config.optimizer == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "adagrad":
            self.optimizer = optim.Adagrad(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "rms":
            self.optimizer = optim.RMSprop(
                self.model.parameters(),
                lr=self.config.learning_rate,
            )
        elif self.config.optimizer == "riemannian":
            param_names = [name for name, param in self.model.named_parameters()]
            self.optimizer = RiemannianOptimizer(
                self.model.parameters(),
                lr=self.config.learning_rate,
                param_names=param_names,
            )
        else:
            raise NotImplementedError("No support for %s optimizer" % self.config.optimizer)

        self.config.summary()

        self.early_stopper = EarlyStopper(self.config.patience, monitor)
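For context, a hedged sketch of the typical driver script that feeds build_model. It assumes KGEArgParser and KnowledgeGraph as used in the project's bundled examples; only Importer is imported by this module itself:

import sys

from pykg2vec.common import Importer, KGEArgParser
from pykg2vec.data.kgcontroller import KnowledgeGraph
from pykg2vec.utils.trainer import Trainer

# Parse command-line hyperparameters (e.g. -mn TransE -ds fb15k).
args = KGEArgParser().get_args(sys.argv[1:])

# Prepare and cache the dataset.
knowledge_graph = KnowledgeGraph(dataset=args.dataset_name)
knowledge_graph.prepare_data()

# Resolve the config and model classes by name, then train.
config_def, model_def = Importer().import_model_config(args.model_name.lower())
config = config_def(args)
model = model_def(**config.__dict__)

trainer = Trainer(model, config)
trainer.build_model()   # picks the optimizer from config.optimizer
trainer.train_model()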
    # Training related functions:

    def train_step_pairwise(self, pos_h, pos_r, pos_t, neg_h, neg_r, neg_t):
        pos_preds = self.model(pos_h, pos_r, pos_t)
        neg_preds = self.model(neg_h, neg_r, neg_t)

        if self.model.model_name.lower() == "rotate":
            loss = self.model.loss(pos_preds, neg_preds, self.config.neg_rate, self.config.alpha)
        else:
            loss = self.model.loss(pos_preds, neg_preds, self.config.margin)
        loss += self.model.get_reg(None, None, None)

        return loss

    def train_step_projection(self, h, r, t, hr_t, tr_h):
        if self.model.model_name.lower() in ["conve", "tucker", "interacte", "hyper", "acre"]:
            pred_tails = self.model(h, r, direction="tail")  # (h, r) -> hr_t forward
            pred_heads = self.model(t, r, direction="head")  # (t, r) -> tr_h backward
            if hasattr(self.config, 'label_smoothing'):
                loss = self.model.loss(pred_heads, pred_tails, tr_h, hr_t,
                                       self.config.label_smoothing, self.config.tot_entity)
            else:
                loss = self.model.loss(pred_heads, pred_tails, tr_h, hr_t, None, None)
        else:
            pred_tails = self.model(h, r, hr_t, direction="tail")  # (h, r) -> hr_t forward
            pred_heads = self.model(t, r, tr_h, direction="head")  # (t, r) -> tr_h backward
            loss = self.model.loss(pred_heads, pred_tails)
        loss += self.model.get_reg(h, r, t)

        return loss

    def train_step_pointwise(self, h, r, t, target):
        preds = self.model(h, r, t)
        loss = self.model.loss(preds, target.type(preds.type()))
        loss += self.model.get_reg(h, r, t)

        return loss
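As a point of reference, a pointwise step reduces to binary classification over scored triples. A minimal self-contained sketch in plain PyTorch; the toy scoring function and embedding sizes here are illustrative, not pykg2vec's:

import torch
import torch.nn.functional as F

emb_e = torch.nn.Embedding(100, 8)   # toy entity embeddings
emb_r = torch.nn.Embedding(20, 8)    # toy relation embeddings

h = torch.tensor([1, 2])
r = torch.tensor([0, 3])
t = torch.tensor([5, 7])
y = torch.tensor([1.0, -1.0])        # +1 for a true triple, -1 for a negative

# Toy TransE-style score: higher means more plausible.
score = -torch.norm(emb_e(h) + emb_r(r) - emb_e(t), p=2, dim=-1)

# Softplus pointwise loss, a common choice for KGE models.
loss = F.softplus(-y * score).mean()
loss.backward()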
    def train_model(self):
        """Function to train the model."""
        self.generator = Generator(self.model, self.config)
        self.monitor = Monitor.FILTERED_MEAN_RANK
        for cur_epoch_idx in range(self.config.epochs):
            self._logger.info("Epoch[%d/%d]" % (cur_epoch_idx, self.config.epochs))

            self.train_model_epoch(cur_epoch_idx)

            if cur_epoch_idx % self.config.test_step == 0:
                self.model.eval()
                with torch.no_grad():
                    metrics = self.evaluator.mini_test(cur_epoch_idx)

                    # Early stop mechanism: after each mini-test, check whether the
                    # monitored metric is still improving. For example, if
                    # test_step == 5, the trainer checks the metrics every 5 epochs.
                    if self.early_stopper.should_stop(metrics):
                        break

                    # Store the best model weights.
                    if self.config.save_model:
                        if self.best_metric is None:
                            self.best_metric = metrics
                            self.save_model()
                        else:
                            # For (filtered) mean rank, lower is better; otherwise higher is better.
                            if self.monitor == Monitor.MEAN_RANK or self.monitor == Monitor.FILTERED_MEAN_RANK:
                                is_better = self.best_metric[self.monitor.value] > metrics[self.monitor.value]
                            else:
                                is_better = self.best_metric[self.monitor.value] < metrics[self.monitor.value]
                            if is_better:
                                self.save_model()
                                self.best_metric = metrics

        self.model.eval()
        with torch.no_grad():
            self.evaluator.full_test(cur_epoch_idx)

        self.evaluator.metric_calculator.save_test_summary(self.model.model_name)

        self.generator.stop()
        self.save_training_result()

        if self.config.disp_result:
            self.display()

        self.export_embeddings()

        return cur_epoch_idx  # the number of epochs run
    def tune_model(self):
        """Function to tune the model."""
        current_loss = float("inf")

        self.generator = Generator(self.model, self.config)
        self.evaluator = Evaluator(self.model, self.config, tuning=True)

        for cur_epoch_idx in range(self.config.epochs):
            current_loss = self.train_model_epoch(cur_epoch_idx, tuning=True)

        self.model.eval()
        with torch.no_grad():
            self.evaluator.full_test(cur_epoch_idx)

        self.generator.stop()

        return current_loss
    def train_model_epoch(self, epoch_idx, tuning=False):
        """Function to train the model for one epoch."""
        acc_loss = 0

        num_batch = self.config.tot_train_triples // self.config.batch_size if not self.config.debug else 10

        self.generator.start_one_epoch(num_batch)

        progress_bar = tqdm(range(num_batch))

        for _ in progress_bar:
            data = list(next(self.generator))

            self.model.train()
            self.optimizer.zero_grad()

            if self.model.training_strategy == TrainingStrategy.PROJECTION_BASED:
                h = torch.LongTensor(data[0]).to(self.config.device)
                r = torch.LongTensor(data[1]).to(self.config.device)
                t = torch.LongTensor(data[2]).to(self.config.device)
                hr_t = data[3].to(self.config.device)
                tr_h = data[4].to(self.config.device)
                loss = self.train_step_projection(h, r, t, hr_t, tr_h)
            elif self.model.training_strategy == TrainingStrategy.POINTWISE_BASED:
                h = torch.LongTensor(data[0]).to(self.config.device)
                r = torch.LongTensor(data[1]).to(self.config.device)
                t = torch.LongTensor(data[2]).to(self.config.device)
                y = torch.LongTensor(data[3]).to(self.config.device)
                loss = self.train_step_pointwise(h, r, t, y)
            elif self.model.training_strategy == TrainingStrategy.PAIRWISE_BASED:
                pos_h = torch.LongTensor(data[0]).to(self.config.device)
                pos_r = torch.LongTensor(data[1]).to(self.config.device)
                pos_t = torch.LongTensor(data[2]).to(self.config.device)
                neg_h = torch.LongTensor(data[3]).to(self.config.device)
                neg_r = torch.LongTensor(data[4]).to(self.config.device)
                neg_t = torch.LongTensor(data[5]).to(self.config.device)
                loss = self.train_step_pairwise(pos_h, pos_r, pos_t, neg_h, neg_r, neg_t)
            else:
                raise NotImplementedError("Unknown training strategy: %s" % self.model.training_strategy)

            loss.backward()
            self.optimizer.step()

            acc_loss += loss.item()

            if not tuning:
                progress_bar.set_description('acc_loss: %f, cur_loss: %f' % (acc_loss, loss))

        self.training_results.append([epoch_idx, acc_loss])

        return acc_loss
    def enter_interactive_mode(self):
        self.build_model()
        self.load_model()

        self._logger.info("""The training/loading of the model has finished!
                             Now enter interactive mode :)
                             -----
                             Example 1: trainer.infer_tails(1,10,topk=5)""")
        self.infer_tails(1, 10, topk=5)

        self._logger.info("""-----
                             Example 2: trainer.infer_heads(10,20,topk=5)""")
        self.infer_heads(10, 20, topk=5)

        self._logger.info("""-----
                             Example 3: trainer.infer_rels(1,20,topk=5)""")
        self.infer_rels(1, 20, topk=5)

    def exit_interactive_mode(self):
        self._logger.info("Thank you for trying out the inference interactive script :)")

    def infer_tails(self, h, r, topk=5):
        tails = self.evaluator.test_tail_rank(h, r, topk).detach().cpu().numpy()
        idx2ent = self.config.knowledge_graph.read_cache_data('idx2entity')
        idx2rel = self.config.knowledge_graph.read_cache_data('idx2relation')
        logs = [
            "",
            "(head, relation)->({},{}) :: Inferred tails->({})".format(h, r, ",".join([str(i) for i in tails])),
            "",
            "head: %s" % idx2ent[h],
            "relation: %s" % idx2rel[r],
        ]

        for idx, tail in enumerate(tails):
            logs.append("%dth predicted tail: %s" % (idx, idx2ent[tail]))

        self._logger.info("\n".join(logs))
        return {tail: idx2ent[tail] for tail in tails}

    def infer_heads(self, r, t, topk=5):
        heads = self.evaluator.test_head_rank(r, t, topk).detach().cpu().numpy()
        idx2ent = self.config.knowledge_graph.read_cache_data('idx2entity')
        idx2rel = self.config.knowledge_graph.read_cache_data('idx2relation')
        logs = [
            "",
            "(relation, tail)->({},{}) :: Inferred heads->({})".format(r, t, ",".join([str(i) for i in heads])),
            "",
            "tail: %s" % idx2ent[t],
            "relation: %s" % idx2rel[r],
        ]

        for idx, head in enumerate(heads):
            logs.append("%dth predicted head: %s" % (idx, idx2ent[head]))

        self._logger.info("\n".join(logs))
        return {head: idx2ent[head] for head in heads}

    def infer_rels(self, h, t, topk=5):
        if self.model.model_name.lower() in ["proje_pointwise", "conve", "tucker"]:
            self._logger.info("%s model doesn't support relation inference in nature." % self.model.model_name)
            return {}

        rels = self.evaluator.test_rel_rank(h, t, topk).detach().cpu().numpy()
        idx2ent = self.config.knowledge_graph.read_cache_data('idx2entity')
        idx2rel = self.config.knowledge_graph.read_cache_data('idx2relation')
        logs = [
            "",
            "(head, tail)->({},{}) :: Inferred rels->({})".format(h, t, ",".join([str(i) for i in rels])),
            "",
            "head: %s" % idx2ent[h],
            "tail: %s" % idx2ent[t],
        ]

        for idx, rel in enumerate(rels):
            logs.append("%dth predicted rel: %s" % (idx, idx2rel[rel]))

        self._logger.info("\n".join(logs))
        return {rel: idx2rel[rel] for rel in rels}

    # Procedural functions:
    def save_model(self):
        """Function to save the model."""
        saved_path = self.config.path_tmp / self.model.model_name
        saved_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), str(saved_path / self.TRAINED_MODEL_FILE_NAME))

        # Save the hyper-parameters into a .npy file alongside the model weights.
        save_path_config = saved_path / self.TRAINED_MODEL_CONFIG_NAME
        np.save(save_path_config, self.config)
    def load_model(self, model_path=None):
        """Function to load the model."""
        if model_path is None:
            model_path_file = self.config.path_tmp / self.model.model_name / self.TRAINED_MODEL_FILE_NAME
            model_path_config = self.config.path_tmp / self.model.model_name / self.TRAINED_MODEL_CONFIG_NAME
        else:
            model_path = Path(model_path)
            model_path_file = model_path / self.TRAINED_MODEL_FILE_NAME
            model_path_config = model_path / self.TRAINED_MODEL_CONFIG_NAME

        if model_path_file.exists() and model_path_config.exists():
            # Restore the saved config, keeping the current load_from_data setting.
            config_temp = np.load(model_path_config, allow_pickle=True).item()
            config_temp.__dict__['load_from_data'] = self.config.__dict__['load_from_data']
            self.config = config_temp

            # Re-instantiate the model class by name, then load the saved weights.
            _, model_def = Importer().import_model_config(self.config.model_name.lower())
            self.model = model_def(**self.config.__dict__)
            self.model.load_state_dict(torch.load(str(model_path_file)))
            self.model.eval()
        else:
            raise ValueError("Cannot load model from %s" % model_path_file)
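A hedged sketch of restoring a trained model for inference, mirroring the pattern in enter_interactive_mode above; the directory path is hypothetical:

trainer = Trainer(model, config)
trainer.build_model()

# Restores the saved config (config.npy) and weights (model.vec.pt), then
# re-instantiates the model class through Importer.
trainer.load_model("/tmp/pykg2vec/TransE")   # hypothetical save directory

trainer.infer_tails(1, 10, topk=5)   # top-5 tails for (head=1, relation=10)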
    def display(self):
        """Function to display embedding."""
        options = {"ent_only_plot": True,
                   "rel_only_plot": not self.config.plot_entity_only,
                   "ent_and_rel_plot": not self.config.plot_entity_only}

        if self.config.plot_embedding:
            viz = Visualization(self.model, self.config, vis_opts=options)
            viz.plot_embedding(resultpath=self.config.path_figures, algos=self.model.model_name, show_label=False)

        if self.config.plot_training_result:
            viz = Visualization(self.model, self.config)
            viz.plot_train_result()

        if self.config.plot_testing_result:
            viz = Visualization(self.model, self.config)
            viz.plot_test_result()
    def export_embeddings(self):
        """Export embeddings in tsv format (both label and vector files).

           With the exported tsv files, you can:
           1) Use the pretrained embeddings in your own applications.
           2) Visualize the embeddings at https://projector.tensorflow.org/ to gain insights.
        """
        save_path = self.config.path_embeddings / self.model.model_name
        save_path.mkdir(parents=True, exist_ok=True)

        idx2ent = self.config.knowledge_graph.read_cache_data('idx2entity')
        idx2rel = self.config.knowledge_graph.read_cache_data('idx2relation')

        with open(str(save_path / "ent_labels.tsv"), 'w') as l_export_file:
            for label in idx2ent.values():
                l_export_file.write(label + "\n")

        with open(str(save_path / "rel_labels.tsv"), 'w') as l_export_file:
            for label in idx2rel.values():
                l_export_file.write(label + "\n")

        for named_embedding in self.model.parameter_list:
            all_ids = list(range(0, int(named_embedding.weight.shape[0])))
            stored_name = named_embedding.name

            # Only export 2-dimensional embedding matrices.
            if len(named_embedding.weight.shape) == 2:
                all_embs = named_embedding.weight.detach().cpu().numpy()
                with open(str(save_path / ("%s.tsv" % stored_name)), 'w') as v_export_file:
                    for idx in all_ids:
                        v_export_file.write("\t".join([str(x) for x in all_embs[idx]]) + "\n")
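The exported files pair row by row: line i of ent_labels.tsv names the entity whose vector is line i of the corresponding embedding tsv. A hedged sketch of loading them back; "ent_embeddings.tsv" stands in for whatever stored_name the model's parameter_list produces:

import numpy as np

with open("ent_labels.tsv") as f:
    labels = [line.rstrip("\n") for line in f]

# "ent_embeddings.tsv" is a hypothetical stored_name; adjust to your model.
vectors = np.loadtxt("ent_embeddings.tsv", delimiter="\t")
assert len(labels) == vectors.shape[0]   # row i of the tsv embeds labels[i]

print(labels[0], vectors[0][:5])   # first entity and the head of its vector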
    def save_training_result(self):
        """Function that saves the training results."""
        files = os.listdir(str(self.config.path_result))
        # Number the result files so repeated runs do not overwrite each other.
        num_files = len([f for f in files if self.model.model_name in f if 'Training' in f])
        df = pd.DataFrame(self.training_results, columns=['Epochs', 'Loss'])
        with open(str(self.config.path_result / (self.model.model_name + '_Training_results_' + str(num_files) + '.csv')), 'w') as fh:
            df.to_csv(fh)
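A hedged sketch of inspecting a saved result file; the file name below is hypothetical and follows the naming pattern above:

import pandas as pd

# "TransE_Training_results_0.csv" is a hypothetical file name.
df = pd.read_csv("TransE_Training_results_0.csv", index_col=0)
print(df.tail())                           # columns: Epochs, Loss
print("final loss:", df['Loss'].iloc[-1])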