#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module is for controlling knowledge graph
"""
import shutil
import pickle
import time
from collections import defaultdict
import numpy as np
from pykg2vec.utils.logger import Logger
from pykg2vec.data.datasets import (
FreebaseFB15k,
DeepLearning50a,
WordNet18,
WordNet18_RR,
YAGO3_10,
FreebaseFB15k_237,
Kinship,
Nations,
UMLS,
NELL_995,
UserDefinedDataset,
)
[docs]class Triple:
""" The class defines the datastructure of the knowledge graph triples.
Triple class is used to store the head, tail and relation triple in both its numerical id and
string form. It also stores the dictonary of (head, relation)=[tail1, tail2,..] and
(tail, relation)=[head1, head2, ...]
Args:
h (str or int): String or integer head entity.
r (str or int): String or integer relation entity.
t (str or int): String or integer tail entity.
Examples:
>>> from pykg2vec.data.kgcontroller import Triple
>>> trip1 = Triple(2,3,5)
>>> trip2 = Triple('Tokyo','isCapitalof','Japan')
"""
def __init__(self, h, r, t):
self.h = h
self.r = r
self.t = t
[docs] def set_ids(self, h, r, t):
""" This function assigns the head, relation and tail.
Args:
h (int): Integer head entity.
r (int): Integer relation entity.
t (int): Integer tail entity.
"""
self.h = h
self.r = r
self.t = t
[docs]class KnowledgeGraph:
""" The class is the main module that handles the knowledge graph.
KnowledgeGraph is responsible for downloading, parsing, processing and preparing
the training, testing and validation dataset.
Args:
dataset_name (str): Name of the datasets
custom_dataset_path (str): The path to custom dataset.
Attributes:
dataset_name (str): The name of the dataset.
dataset (object): The dataset object isntance.
triplets (dict): dictionary with three list of training, testing and validation triples.
relations (list):list of all the relations.
entities (list): List of all the entities.
entity2idx (dict): Dictionary for mapping string name of entities to unique numerical id.
idx2entity (dict): Dictionary for mapping the id to string.
relation2idx (dict): Dictionary for mapping the id to string.
idx2relation (dict): Dictionary for mapping the id to string.
hr_t (dict): Dictionary with set as a default key and list as values.
tr_h (dict): Dictionary with set as a default key and list as values.
hr_t_train (dict): Dictionary with set as a default key and list as values.
tr_h_train (dict): Dictionary with set as a default key and list as values.
relation_property (list): list storing the entities tied to a specific relation.
kg_meta (object): Object storing the statistics metadata of the dataset.
Examples:
>>> from pykg2vec.data.kgcontroller import KnowledgeGraph
>>> knowledge_graph = KnowledgeGraph(dataset='Freebase15k')
>>> knowledge_graph.prepare_data()
"""
_logger = Logger().get_logger(__name__)
def __init__(self, dataset='Freebase15k', custom_dataset_path=None):
self.dataset_name = dataset
if dataset.lower() == 'freebase15k' or dataset.lower() == 'fb15k':
self.dataset = FreebaseFB15k()
elif dataset.lower() == 'deeplearning50a' or dataset.lower() == 'dl50a':
self.dataset = DeepLearning50a()
elif dataset.lower() == 'wordnet18' or dataset.lower() == 'wn18':
self.dataset = WordNet18()
elif dataset.lower() == 'wordnet18_rr' or dataset.lower() == 'wn18_rr':
self.dataset = WordNet18_RR()
elif dataset.lower() == 'yago3_10' or dataset.lower() == 'yago':
self.dataset = YAGO3_10()
elif dataset.lower() == 'freebase15k_237' or dataset.lower() == 'fb15k_237':
self.dataset = FreebaseFB15k_237()
elif dataset.lower() == 'kinship' or dataset.lower() == 'ks':
self.dataset = Kinship()
elif dataset.lower() == 'nations':
self.dataset = Nations()
elif dataset.lower() == 'umls':
self.dataset = UMLS()
elif dataset.lower() == 'nell_995':
self.dataset = NELL_995()
elif custom_dataset_path is not None:
# if the dataset does not match with existing one, check if it exists in user's local space.
# if it still can't find corresponding folder, raise exception in UserDefinedDataset.__init__()
self.dataset = UserDefinedDataset(dataset, custom_dataset_path)
else:
raise ValueError("Unknown dataset: %s" % dataset)
# KG data structure stored in triplet format
self.triplets = {'train': [], 'test': [], 'valid': []}
self.triple_store = self.triplets
# TODO: should also have graph-based data structure for a KG.
self.relations = []
self.entities = []
self.entity2idx = {}
self.idx2entity = {}
self.relation2idx = {}
self.idx2relation = {}
self.hr_t = defaultdict(set)
self.tr_h = defaultdict(set)
self.hr_t_train = defaultdict(set)
self.tr_h_train = defaultdict(set)
self.hr_t_valid = defaultdict(set)
self.tr_h_valid = defaultdict(set)
self.relation_property = []
if self.dataset.is_meta_cache_exists():
self.kg_meta = self.dataset.read_metadata()
else:
self.kg_meta = KGMetaData()
self.prepare_data()
def force_prepare_data(self):
shutil.rmtree(str(self.dataset.root_path), ignore_errors=True)
time.sleep(1)
self.__init__(dataset=self.dataset_name)
[docs] def prepare_data(self):
"""Function to prepare the dataset"""
if self.dataset.is_meta_cache_exists():
return
self.read_entities()
self.read_relations()
self.read_mappings()
self.read_triple_ids('train')
self.read_triple_ids('test')
self.read_triple_ids('valid')
self.read_hr_t()
self.read_tr_h()
self.read_hr_t_train()
self.read_tr_h_train()
self.read_hr_t_valid()
self.read_tr_h_valid()
self.read_relation_property()
self.kg_meta.tot_relation = len(self.relations)
self.kg_meta.tot_entity = len(self.entities)
self.kg_meta.tot_valid_triples = len(self.triplets['valid'])
self.kg_meta.tot_test_triples = len(self.triplets['test'])
self.kg_meta.tot_train_triples = len(self.triplets['train'])
self.kg_meta.tot_triple = self.kg_meta.tot_valid_triples + \
self.kg_meta.tot_test_triples + \
self.kg_meta.tot_train_triples
self._cache_data()
def _cache_data(self):
"""Function to cache the prepared dataset in the memory"""
with open(str(self.dataset.cache_metadata_path), 'wb') as f:
pickle.dump(self.kg_meta, f)
with open(str(self.dataset.cache_triplet_paths['train']), 'wb') as f:
pickle.dump(self.triplets['train'], f)
with open(str(self.dataset.cache_triplet_paths['test']), 'wb') as f:
pickle.dump(self.triplets['test'], f)
with open(str(self.dataset.cache_triplet_paths['valid']), 'wb') as f:
pickle.dump(self.triplets['valid'], f)
with open(str(self.dataset.cache_hr_t_path), 'wb') as f:
pickle.dump(self.hr_t, f)
with open(str(self.dataset.cache_tr_h_path), 'wb') as f:
pickle.dump(self.tr_h, f)
with open(str(self.dataset.cache_hr_t_train_path), 'wb') as f:
pickle.dump(self.hr_t_train, f)
with open(str(self.dataset.cache_tr_h_train_path), 'wb') as f:
pickle.dump(self.tr_h_train, f)
with open(str(self.dataset.cache_idx2entity_path), 'wb') as f:
pickle.dump(self.idx2entity, f)
with open(str(self.dataset.cache_idx2relation_path), 'wb') as f:
pickle.dump(self.idx2relation, f)
with open(str(self.dataset.cache_relation2idx_path), 'wb') as f:
pickle.dump(self.relation2idx, f)
with open(str(self.dataset.cache_entity2idx_path), 'wb') as f:
pickle.dump(self.entity2idx, f)
with open(str(self.dataset.cache_relationproperty_path), 'wb') as f:
pickle.dump(self.relation_property, f)
[docs] def read_cache_data(self, key):
"""Function to read the cached dataset from the memory"""
if key == 'triplets_train':
with open(str(self.dataset.cache_triplet_paths['train']), 'rb') as f:
triplets = pickle.load(f)
return triplets
elif key == 'triplets_test':
with open(str(self.dataset.cache_triplet_paths['test']), 'rb') as f:
triplets = pickle.load(f)
return triplets
elif key == 'triplets_valid':
with open(str(self.dataset.cache_triplet_paths['valid']), 'rb') as f:
triplets = pickle.load(f)
return triplets
elif key == 'hr_t':
with open(str(self.dataset.cache_hr_t_path), 'rb') as f:
hr_t = pickle.load(f)
return hr_t
elif key == 'tr_h':
with open(str(self.dataset.cache_tr_h_path), 'rb') as f:
tr_h = pickle.load(f)
return tr_h
elif key == 'hr_t_train':
with open(str(self.dataset.cache_hr_t_train_path), 'rb') as f:
hr_t_train = pickle.load(f)
return hr_t_train
elif key == 'tr_h_train':
with open(str(self.dataset.cache_tr_h_train_path), 'rb') as f:
tr_h_train = pickle.load(f)
return tr_h_train
elif key == 'idx2entity':
with open(str(self.dataset.cache_idx2entity_path), 'rb') as f:
idx2entity = pickle.load(f)
return idx2entity
elif key == 'idx2relation':
with open(str(self.dataset.cache_idx2relation_path), 'rb') as f:
idx2relation = pickle.load(f)
return idx2relation
elif key == 'entity2idx':
with open(str(self.dataset.cache_entity2idx_path), 'rb') as f:
entity2idx = pickle.load(f)
return entity2idx
elif key == 'relation2idx':
with open(str(self.dataset.cache_relation2idx_path), 'rb') as f:
relation2idx = pickle.load(f)
return relation2idx
elif key == 'relationproperty':
with open(str(self.dataset.cache_relationproperty_path), 'rb') as f:
relation_property = pickle.load(f)
return relation_property
else:
raise ValueError('Unknown cache data key %s' % key)
[docs] def is_cache_exists(self):
"""Function to check if the dataset is cached in the memory"""
return self.dataset.is_meta_cache_exists()
[docs] def read_triplets(self, set_type):
'''
read triplets from txt files in dataset folder.
(in string format)
'''
triplets = self.triplets[set_type]
if len(triplets) == 0:
with open(str(self.dataset.data_paths[set_type]), 'r', encoding='utf-8') as file:
for line in file.readlines():
s, p, o = line.split('\t')
triplets.append(Triple(s.strip(), p.strip(), o.strip()))
return triplets
[docs] def read_entities(self):
""" Function to read the entities. """
if len(self.entities) == 0:
entities = set()
all_triplets = self.read_triplets('train') + \
self.read_triplets('valid') + \
self.read_triplets('test')
for triplet in all_triplets:
entities.add(triplet.h)
entities.add(triplet.t)
self.entities = np.sort(list(entities))
return self.entities
[docs] def read_relations(self):
""" Function to read the relations. """
if len(self.relations) == 0:
relations = set()
all_triplets = self.read_triplets('train') + \
self.read_triplets('valid') + \
self.read_triplets('test')
for triplet in all_triplets:
relations.add(triplet.r)
self.relations = np.sort(list(relations))
return self.relations
[docs] def read_mappings(self):
""" Function to generate the mapping from string name to integer ids. """
self.entity2idx = {v: k for k, v in enumerate(self.read_entities())} ##
self.idx2entity = {v: k for k, v in self.entity2idx.items()}
self.relation2idx = {v: k for k, v in enumerate(self.read_relations())} ##
self.idx2relation = {v: k for k, v in self.relation2idx.items()}
[docs] def read_triple_ids(self, set_type):
""" Function to read the triple idx.
Args:
set_type (str): Type of data, eithe train, test or valid.
"""
# assert entities can not be none
# assert relations can not be none
triplets = self.triplets[set_type]
entity2idx = self.entity2idx
relation2idx = self.relation2idx
if len(triplets) != 0:
for t in triplets:
t.set_ids(entity2idx[t.h], relation2idx[t.r], entity2idx[t.t])
return triplets
[docs] def read_hr_t(self):
""" Function to read the list of tails for the given head and relation pair. """
for set_type in self.triplets:
triplets = self.triplets[set_type]
for t in triplets:
self.hr_t[(t.h, t.r)].add(t.t)
return self.hr_t
[docs] def read_tr_h(self):
""" Function to read the list of heads for the given tail and relation pair. """
for set_type in self.triplets:
triplets = self.triplets[set_type]
for t in triplets:
self.tr_h[(t.t, t.r)].add(t.h)
return self.tr_h
[docs] def read_hr_t_train(self):
""" Function to read the list of tails for the given head and relation pair for the training set. """
triplets = self.triplets['train']
for t in triplets:
self.hr_t_train[(t.h, t.r)].add(t.t)
return self.hr_t_train
[docs] def read_tr_h_train(self):
""" Function to read the list of heads for the given tail and relation pair for the training set. """
triplets = self.triplets['train']
for t in triplets:
self.tr_h_train[(t.t, t.r)].add(t.h)
return self.tr_h_train
[docs] def read_hr_t_valid(self):
""" Function to read the list of tails for the given head and relation pair for the valid set. """
triplets = self.triplets['valid']
for t in triplets:
self.hr_t_valid[(t.h, t.r)].add(t.t)
return self.hr_t_valid
[docs] def read_tr_h_valid(self):
""" Function to read the list of heads for the given tail and relation pair for the valid set. """
triplets = self.triplets['valid']
for t in triplets:
self.tr_h_valid[(t.t, t.r)].add(t.h)
return self.tr_h_valid
[docs] def read_relation_property(self):
""" Function to read the relation property.
Returns:
list: Returns the list of relation property.
"""
relation_property_head = {x: [] for x in range(len(self.relations))}
relation_property_tail = {x: [] for x in range(len(self.relations))}
for t in self.triplets['train']:
relation_property_head[t.r].append(t.h)
relation_property_tail[t.r].append(t.t)
self.relation_property = {}
for x in relation_property_head.keys():
value_up = len(set(relation_property_tail[x]))
value_bot = len(set(relation_property_head[x])) + len(set(relation_property_tail[x]))
if value_bot == 0:
value = 0
else:
value = value_up / value_bot
self.relation_property[x] = value
return self.relation_property
# reserved for debugging
[docs] def dump(self):
""" Function to dump statistic information of a dataset """
# dump key information
dump = [
"",
"----------Metadata Info for Dataset:%s----------------" % self.dataset_name,
"Total Training Triples :%s" % self.kg_meta.tot_train_triples,
"Total Testing Triples :%s" % self.kg_meta.tot_test_triples,
"Total validation Triples :%s" % self.kg_meta.tot_valid_triples,
"Total Entities :%s" % self.kg_meta.tot_entity,
"Total Relations :%s" % self.kg_meta.tot_relation,
"---------------------------------------------",
"",
]
self._logger.info("\n".join(dump))
return dump