# Source code for pykg2vec.utils.visualization

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module provides helpers for visualizing training/testing results and learned knowledge-graph embeddings.
"""
import os
import re

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn
import torch
from matplotlib import colors as mcolors
from sklearn.manifold import TSNE

from pykg2vec.utils.logger import Logger

# Select seaborn's "darkgrid" figure style once, at import time, so the plots
# saved by this module share the same look.
seaborn.set_style("darkgrid")

class Visualization:
    """Class to aid in visualizing the results and embeddings.

    Args:
        model (object): Model object exposing ``embed(h, r, t)``. May be ``None``
            when only :meth:`plot_train_result` / :meth:`plot_test_result` are used.
        config (object): Configuration object. The attributes read by this class
            are ``knowledge_graph``, ``disp_triple_num``, ``device``,
            ``path_result``, ``path_figures``, ``dataset_name`` and ``hits``.
        vis_opts (dict): Optional switches ``ent_only_plot``, ``rel_only_plot`` and
            ``ent_and_rel_plot``. A missing key (or no dict at all) means ``False``.

    Examples:
        >>> from pykg2vec.utils.visualization import Visualization
        >>> from pykg2vec.utils.trainer import Trainer
        >>> from pykg2vec.models.TransE import TransE
        >>> model = TransE()
        >>> trainer = Trainer(model=model)
        >>> trainer.build_model()
        >>> trainer.train_model()
        >>> viz = Visualization(model=model, config=config)  # config used for training
        >>> viz.plot_train_result()
    """

    _logger = Logger().get_logger(__name__)

    def __init__(self, model, config, vis_opts=None):
        opts = vis_opts or {}
        self.ent_only_plot = opts.get("ent_only_plot", False)
        self.rel_only_plot = opts.get("rel_only_plot", False)
        self.ent_and_rel_plot = opts.get("ent_and_rel_plot", False)

        self.model = model
        self.config = config
        self.algo_list = ['ANALOGY', 'Complex', 'ComplexN3', 'ConvE', 'CP', 'DistMult', 'DistMult2', 'HoLE',
                          'KG2E', 'NTN', 'ProjE_pointwise', 'Rescal', 'RotatE', 'SimplE_avg', 'SimplE_ignr',
                          'SLM', 'SME_Bilinear', 'SME_Linear', 'TransD', 'TransE', 'TransH', 'TransM',
                          'TransR', 'TuckER']

        # Names and embeddings of the sampled validation triples (filled by get_idx_n_emb).
        self.h_name = []
        self.r_name = []
        self.t_name = []
        self.h_emb = []
        self.r_emb = []
        self.t_emb = []
        self.h_proj_emb = []
        self.r_proj_emb = []
        self.t_proj_emb = []

        if self.model is not None:
            knowledge_graph = self.config.knowledge_graph
            self.validation_triples_ids = knowledge_graph.read_cache_data('triplets_valid')
            self.idx2entity = knowledge_graph.read_cache_data('idx2entity')
            self.idx2relation = knowledge_graph.read_cache_data('idx2relation')
            self.get_idx_n_emb()

    def get_idx_n_emb(self):
        """Sample validation triples and collect their names and embeddings.

        ``config.disp_triple_num`` indices are drawn with ``np.random.choice``
        (i.e. with replacement) from the cached validation triples. For every
        sampled triple, the head/relation/tail names are appended to
        ``h_name``/``r_name``/``t_name``. The tensors returned by ``model.embed``
        are appended to ``h_emb``/``r_emb``/``t_emb``, and also to the
        ``*_proj_emb`` lists when ``ent_and_rel_plot`` is enabled.
        """
        sampled = np.random.choice(len(self.validation_triples_ids), self.config.disp_triple_num)
        triples = [self.validation_triples_ids[i] for i in sampled]
        device = self.config.device

        # The embeddings are only plotted, so no autograd graph is needed.
        with torch.no_grad():
            for t in triples:
                self.h_name.append(self.idx2entity[t.h])
                self.r_name.append(self.idx2relation[t.r])
                self.t_name.append(self.idx2entity[t.t])

                emb_h, emb_r, emb_t = self.model.embed(torch.LongTensor([t.h]).to(device),
                                                       torch.LongTensor([t.r]).to(device),
                                                       torch.LongTensor([t.t]).to(device))
                self.h_emb.append(emb_h)
                self.r_emb.append(emb_r)
                self.t_emb.append(emb_t)

                if self.ent_and_rel_plot:
                    # The relation-space plot uses the same tensors returned by `embed`.
                    self.h_proj_emb.append(emb_h)
                    self.r_proj_emb.append(emb_r)
                    self.t_proj_emb.append(emb_t)

    def plot_embedding(self, resultpath=None, algos=None, show_label=False, disp_num_r_n_e=20):
        """Function to plot the embedding.

        Args:
            resultpath (str or Path): Directory where the plots are saved.
                Defaults to ``config.path_figures``.
            algos (str): Name of the algorithm that generated the embedding; used as
                the file-name prefix. Defaults to the model's class name.
            show_label (bool): If True, the entity/relation names are drawn.
            disp_num_r_n_e (int): Number of (head, relation, tail) triples drawn in
                the relation-space plot.
        """
        assert self.model is not None, 'Please provide a model!'
        if resultpath is None:
            resultpath = self.config.path_figures
        if algos is None:
            algos = type(self.model).__name__

        if self.ent_only_plot:
            x = self._reduce_to_2d(self.h_emb + self.t_emb)
            ent_names = np.asarray(np.concatenate((self.h_name, self.t_name), axis=0))
            self.draw_embedding(x, ent_names, resultpath, algos + '_entity_plot', show_label)

        if self.rel_only_plot:
            x = self._reduce_to_2d(self.r_emb)
            self.draw_embedding(x, self.r_name, resultpath, algos + '_rel_plot', show_label)

        if self.ent_and_rel_plot:
            length = len(self.h_proj_emb)
            x = self._reduce_to_2d(self.h_proj_emb + self.r_proj_emb + self.t_proj_emb)
            # Rows were stacked as [heads, relations, tails]; split them back apart.
            h_embs = x[:length, :]
            r_embs = x[length:2 * length, :]
            t_embs = x[2 * length:3 * length, :]
            self.draw_embedding_rel_space(h_embs[:disp_num_r_n_e],
                                          r_embs[:disp_num_r_n_e],
                                          t_embs[:disp_num_r_n_e],
                                          self.h_name[:disp_num_r_n_e],
                                          self.r_name[:disp_num_r_n_e],
                                          self.t_name[:disp_num_r_n_e],
                                          resultpath, algos + '_ent_n_rel_plot', show_label)

    def _reduce_to_2d(self, tensors):
        """Concatenate `tensors` along dim 0 and project the rows to 2-D with t-SNE."""
        x = torch.cat(tensors, dim=0).detach().cpu().numpy()
        self._logger.info("\t Reducing dimension using TSNE to 2!")
        # scikit-learn requires perplexity < n_samples; its default (30) is too
        # large when only a few triples were sampled.
        perplexity = min(30.0, max(x.shape[0] - 1, 1))
        return TSNE(n_components=2, perplexity=perplexity).fit_transform(x)

    def plot_train_result(self):
        """Plot the training loss curves of every algorithm with saved results.

        For each name in ``algo_list``, the newest training-results CSV in
        ``config.path_result`` is read. Its ``Epochs``/``Loss`` columns are drawn
        as one line per algorithm. The figure is saved to ``config.path_figures``
        as ``<dataset>_training_loss_plot_<k>.pdf``.
        """
        path = self.config.path_result
        result = self.config.path_figures
        files = os.listdir(str(path))

        for d in [self.config.dataset_name]:
            frames = []
            for a in self.algo_list:
                file_name = self._latest_result_file(files, a, '_training_results_')
                if file_name is None:
                    continue
                with open(str(path / file_name), 'r') as fh:
                    df_2 = pd.read_csv(fh)
                frames.append(pd.DataFrame({
                    'Epochs': df_2['Epochs'].values,
                    'Loss': df_2['Loss'].values,
                    'Algorithm': [a] * len(df_2),
                }))

            if not frames:
                self._logger.warning("No training results found in %s; skipping the loss plot for %s.", path, d)
                continue
            df = pd.concat(frames, ignore_index=True)

            plt.figure()
            seaborn.lineplot(x="Epochs", y="Loss", hue="Algorithm", markers=True, dashes=False, data=df)
            file_no = self._count_matching(os.listdir(str(result)), d.lower(), 'training')
            plt.savefig(str(result / (d + '_training_loss_plot_' + str(file_no) + '.pdf')),
                        bbox_inches='tight', dpi=300)
            plt.close()

    def plot_test_result(self):
        """Summarize and plot the testing results of every algorithm with saved results.

        For each name in ``algo_list``, the newest testing-results CSV in
        ``config.path_result`` is read. The rows of the largest epoch are kept and
        written to ``config.path_figures`` twice: as a LaTeX table
        (``<dataset>_testing_latex_table_<k>.txt``) and as a CSV table
        (``<dataset>_testing_table_<k>.csv``). Two bar plots are also saved there:
        one for the rank metrics and one for the hits metrics.
        """
        algo = self.algo_list
        path = self.config.path_result
        result = self.config.path_figures
        data = [self.config.dataset_name]
        hits = self.config.hits
        assert path is not None and algo is not None and data is not None, 'Please provide valid path, algorithm and dataset!'
        files = os.listdir(str(path))

        for d in data:
            frames = []
            for a in algo:
                file_name = self._latest_result_file(files, a, 'testing')
                if file_name is None:
                    continue
                with open(str(path / file_name), 'r') as fh:
                    df_2 = pd.read_csv(fh)
                columns = {
                    'Algorithm': [a] * len(df_2),
                    'Epochs': df_2['Epoch'].values,
                    'Mean Rank': df_2['Mean Rank'].values,
                    'Filt Mean Rank': df_2['Filtered Mean Rank'].values,
                }
                for hit in hits:
                    columns['Hits' + str(hit)] = df_2['Hit-%d Ratio' % hit].values
                    columns['Filt Hits' + str(hit)] = df_2['Filtered Hit-%d Ratio' % hit].values
                frames.append(pd.DataFrame(columns))

            if not frames:
                self._logger.warning("No testing results found in %s; skipping the test summary for %s.", path, d)
                continue
            df = pd.concat(frames, ignore_index=True)

            figures = os.listdir(str(result))
            # Keep only the rows of the largest epoch; the epoch column is then dropped.
            df_4 = df.loc[df['Epochs'] == df['Epochs'].max()]
            df_4 = df_4.loc[:, df_4.columns != 'Epochs']

            file_no = self._count_matching(figures, d.lower(), 'testing', 'latex')
            with open(str(result / (d + '_testing_latex_table_' + str(file_no + 1) + '.txt')), 'w') as fh:
                fh.write(df_4.to_latex(index=False))

            file_no = self._count_matching(figures, d.lower(), 'testing', 'table', 'csv')
            with open(str(result / (d + '_testing_table_' + str(file_no + 1) + '.csv')), 'w') as fh:
                df_4.to_csv(fh, index=False)

            # Long format: one (Metrics, Algorithm, Score) row per metric and algorithm.
            metrics = [f for f in df_4.columns if f != 'Algorithm']
            df_5 = df_4.melt(id_vars='Algorithm', value_vars=metrics, var_name='Metrics', value_name='Score')
            is_hits = df_5['Metrics'].str.contains('Hits')

            file_no = self._count_matching(figures, d.lower(), 'testing', 'rank_plot')
            self._save_metric_barplot(
                df_5[~is_hits],
                ["#d46a7e", "#d5b60a", "#9b59b6", "#3498db", "#95a5a6", "#34495e", "#2ecc71", "#e74c3c"],
                str(result / (d + '_testing_rank_plot_' + str(file_no + 1) + '.pdf')))

            file_no = self._count_matching(figures, d.lower(), 'testing', 'hits_plot')
            self._save_metric_barplot(
                df_5[is_hits],
                ["#3498db", "#95a5a6", "#34495e", "#2ecc71", "#e74c3c", "#d46a7e", "#d5b60a", "#9b59b6"],
                str(result / (d + '_testing_hits_plot_' + str(file_no + 1) + '.pdf')))

    @staticmethod
    def _latest_result_file(files, algo, kind):
        """Return the newest CSV in `files` holding `kind` results for `algo`, or None.

        A file qualifies when its lower-cased name ends with ``.csv``, contains
        `kind`, and mentions `algo` as a whole token. The token rule stops e.g.
        ``DistMult`` from picking up ``distmult2_...`` files. Among qualifying files,
        the one with the largest trailing run number (``..._<n>.csv``) is returned.
        """
        algo_pattern = re.compile(r'(?<![a-z0-9])' + re.escape(algo.lower()) + r'(?![a-z0-9])')
        candidates = []
        for name in files:
            lowered = name.lower()
            if not lowered.endswith('.csv') or kind not in lowered or not algo_pattern.search(lowered):
                continue
            run = re.search(r'(\d+)\.csv$', lowered)
            candidates.append((int(run.group(1)) if run else -1, lowered, name))
        return max(candidates)[2] if candidates else None

    @staticmethod
    def _count_matching(file_names, *keywords):
        """Count the names in `file_names` whose lower-cased form contains every keyword."""
        return sum(1 for name in file_names if all(k in name.lower() for k in keywords))

    @staticmethod
    def _save_metric_barplot(data, palette, filename):
        """Draw a Metrics/Score bar plot (hue = Algorithm) of `data` and save it to `filename`."""
        plt.figure()
        g = seaborn.barplot(x="Metrics", y='Score', hue="Algorithm", palette=palette, data=data)
        g.legend(loc='upper center', bbox_to_anchor=(0.5, 1.14), ncol=6)
        g.tick_params(labelsize=6)
        plt.savefig(filename, bbox_inches='tight', dpi=300)
        plt.close()

    @staticmethod
    def _color_map(names):
        """Map every distinct name to a matplotlib color name, cycling if names outnumber colors.

        Names are sorted first, so the assignment is the same on every run.
        """
        palette = list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())
        return {name: palette[i % len(palette)] for i, name in enumerate(sorted(set(names), key=str))}

    @staticmethod
    def draw_embedding(embs, names, resultpath, algos, show_label):
        """Function to draw the embedding.

        Args:
            embs (matrix): Two dimensional embeddings; row ``i`` belongs to ``names[i]``.
            names (list): List of string names.
            resultpath (str or Path): Directory where the plot is saved (created if missing).
            algos (str): File-name prefix, usually the algorithm that generated the embedding.
            show_label (bool): If True, prints the string names of the entities and relations.
        """
        node_color_mp = Visualization._color_map(names)

        G = nx.Graph()
        G.add_nodes_from(range(len(names)))
        pos = {i: embs[i] for i in range(len(names))}
        colors = [node_color_mp[name] for name in names]

        plt.figure()
        nodes_draw = nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=50)
        nodes_draw.set_edgecolor('k')
        if show_label:
            nx.draw_networkx_labels(G, pos, labels=dict(enumerate(names)), font_size=8)

        os.makedirs(resultpath, exist_ok=True)
        files = os.listdir(resultpath)
        file_no = len([c for c in files if algos + '_embedding_plot' in c])
        filename = algos + '_embedding_plot_' + str(file_no) + '.png'
        plt.savefig(os.path.join(resultpath, filename), bbox_inches='tight', dpi=300)
        plt.close()

    @staticmethod
    def draw_embedding_rel_space(h_emb, r_emb, t_emb, h_name, r_name, t_name, resultpath, algos, show_label):
        """Function to draw the embedding in relation space.

        Each triple ``i`` becomes a head (circle) -> relation (diamond) -> tail (star) path.

        Args:
            h_emb (matrix): Two dimensional embeddings of head.
            r_emb (matrix): Two dimensional embeddings of relation.
            t_emb (matrix): Two dimensional embeddings of tail.
            h_name (list): List of string names of the heads.
            r_name (list): List of string names of the relations.
            t_name (list): List of string names of the tails.
            resultpath (str or Path): Directory where the plot is saved (created if missing).
            algos (str): File-name prefix, usually the algorithm that generated the embedding.
            show_label (bool): If True, prints the string names of the entities and relations.
        """
        # Entities and relations are colored independently, so colors may repeat across the two.
        ent_color_mp = Visualization._color_map(list(h_name) + list(t_name))
        rel_color_mp = Visualization._color_map(r_name)

        G = nx.DiGraph()
        pos = {}
        labels = {}
        head_nodes, rel_nodes, tail_nodes = [], [], []
        head_colors, rel_colors, tail_colors = [], [], []
        for i, head in enumerate(h_name):
            head_id, rel_id, tail_id = 3 * i, 3 * i + 1, 3 * i + 2
            G.add_edge(head_id, rel_id)
            G.add_edge(rel_id, tail_id)
            head_nodes.append(head_id)
            rel_nodes.append(rel_id)
            tail_nodes.append(tail_id)
            head_colors.append(ent_color_mp[head])
            rel_colors.append(rel_color_mp[r_name[i]])
            tail_colors.append(ent_color_mp[t_name[i]])
            pos[head_id] = h_emb[i]
            pos[rel_id] = r_emb[i]
            pos[tail_id] = t_emb[i]
            labels[head_id] = head
            labels[rel_id] = r_name[i]
            labels[tail_id] = t_name[i]

        plt.figure()
        nodes_draw = nx.draw_networkx_nodes(G, pos, nodelist=head_nodes, node_color=head_colors,
                                            node_shape='o', node_size=50)
        nodes_draw.set_edgecolor('k')
        nodes_draw = nx.draw_networkx_nodes(G, pos, nodelist=rel_nodes, node_color=rel_colors,
                                            node_size=50, node_shape='D')
        nodes_draw.set_edgecolor('k')
        nodes_draw = nx.draw_networkx_nodes(G, pos, nodelist=tail_nodes, node_color=tail_colors,
                                            node_shape='*', node_size=50)
        nodes_draw.set_edgecolor('k')
        if show_label:
            nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)
        nx.draw_networkx_edges(G, pos, arrows=True, width=0.5, alpha=0.5)

        os.makedirs(resultpath, exist_ok=True)
        files = os.listdir(resultpath)
        file_no = len([c for c in files if algos + '_embedding_plot' in c])
        plt.savefig(os.path.join(resultpath, algos + '_embedding_plot_' + str(file_no) + '.png'),
                    bbox_inches='tight', dpi=300)
        plt.close()