Source code for pykg2vec.models.pointwise

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
from numpy.random import RandomState

from pykg2vec.models.KGMeta import PointwiseModel
from pykg2vec.models.Domain import NamedEmbedding
from pykg2vec.utils.criterion import Criterion


class ANALOGY(PointwiseModel):
    """`Analogical Inference for Multi-relational Embeddings`_ (ANALOGY).

    The triple score is the sum of a DistMult-style term computed on
    real-valued embeddings of size ``hidden_size`` and a ComplEx-style term
    computed on complex embeddings whose real and imaginary parts each have
    size ``hidden_size // 2``.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``.

    .. _Analogical Inference for Multi-relational Embeddings:
        http://proceedings.mlr.press/v70/liu17d/liu17d.pdf
    """

    def __init__(self, **kwargs):
        super(ANALOGY, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        k = self.hidden_size
        # Real-valued embeddings used by the DistMult term.
        self.ent_embeddings = NamedEmbedding("ent_embedding", self.tot_entity, k)
        self.rel_embeddings = NamedEmbedding("rel_embedding", self.tot_relation, k)
        # Real / imaginary parts used by the ComplEx term (k // 2 each).
        self.ent_embeddings_real = NamedEmbedding("emb_e_real", self.tot_entity, k // 2)
        self.ent_embeddings_img = NamedEmbedding("emb_e_img", self.tot_entity, k // 2)
        self.rel_embeddings_real = NamedEmbedding("emb_rel_real", self.tot_relation, k // 2)
        self.rel_embeddings_img = NamedEmbedding("emb_rel_img", self.tot_relation, k // 2)

        nn.init.xavier_uniform_(self.ent_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)
        nn.init.xavier_uniform_(self.ent_embeddings_real.weight)
        nn.init.xavier_uniform_(self.ent_embeddings_img.weight)
        nn.init.xavier_uniform_(self.rel_embeddings_real.weight)
        nn.init.xavier_uniform_(self.rel_embeddings_img.weight)

        self.parameter_list = [
            self.ent_embeddings,
            self.rel_embeddings,
            self.ent_embeddings_real,
            self.ent_embeddings_img,
            self.rel_embeddings_real,
            self.rel_embeddings_img,
        ]

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the real-valued (DistMult) embeddings.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns head, relation and tail embedding Tensors.
        """
        h_emb = self.ent_embeddings(h)
        r_emb = self.rel_embeddings(r)
        t_emb = self.ent_embeddings(t)
        return h_emb, r_emb, t_emb

    def embed_complex(self, h, r, t):
        """Function to get the complex (ComplEx) embeddings.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns real and imaginary values of head, relation and tail embedding.
        """
        h_emb_real = self.ent_embeddings_real(h)
        h_emb_img = self.ent_embeddings_img(h)

        r_emb_real = self.rel_embeddings_real(r)
        r_emb_img = self.rel_embeddings_img(r)

        t_emb_real = self.ent_embeddings_real(t)
        t_emb_img = self.ent_embeddings_img(t)

        return h_emb_real, h_emb_img, r_emb_real, r_emb_img, t_emb_real, t_emb_img

    def forward(self, h, r, t):
        """Return the negated sum of the ComplEx and DistMult triple scores."""
        h_e, r_e, t_e = self.embed(h, r, t)
        h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img = self.embed_complex(h, r, t)

        # Re(<h, r, conj(t)>) written out on real/imaginary parts.
        complex_loss = -(h_e_real * t_e_real * r_e_real + h_e_img * t_e_img * r_e_real
                         + h_e_real * t_e_img * r_e_img - h_e_img * t_e_real * r_e_img).sum(axis=-1)
        distmult_loss = -(h_e * r_e * t_e).sum(axis=-1)
        return complex_loss + distmult_loss

    def get_reg(self, h, r, t, reg_type="F2"):
        """Regularisation term for a batch of triples.

        Args:
            h, r, t (Tensor): Head, relation and tail ids.
            reg_type (str): ``"F2"`` (sum of squares) or ``"N3"`` (sum of
                absolute cubes); case-insensitive.

        Returns:
            Tensor: ``lmbda`` times the batch mean of the per-triple penalty.

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        h_e, r_e, t_e = self.embed(h, r, t)
        h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img = self.embed_complex(h, r, t)

        if reg_type.lower() == 'f2':
            regul_term = (h_e_real ** 2 + h_e_img ** 2 + r_e_real ** 2 + r_e_img ** 2
                          + t_e_real ** 2 + t_e_img ** 2).sum(axis=-1).mean()
            regul_term += (h_e ** 2 + r_e ** 2 + t_e ** 2).sum(axis=-1).mean()
        elif reg_type.lower() == 'n3':
            # N3 is a norm: it needs |x| ** 3. Raw cubes are negative for
            # negative coordinates, which would make the penalty unbounded below.
            regul_term = (h_e_real.abs() ** 3 + h_e_img.abs() ** 3 + r_e_real.abs() ** 3
                          + r_e_img.abs() ** 3 + t_e_real.abs() ** 3
                          + t_e_img.abs() ** 3).sum(axis=-1).mean()
            regul_term += (h_e.abs() ** 3 + r_e.abs() ** 3 + t_e.abs() ** 3).sum(axis=-1).mean()
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)

        return self.lmbda * regul_term
class Complex(PointwiseModel):
    """`Complex Embeddings for Simple Link Prediction`_ (ComplEx).

    ComplEx is an enhanced version of DistMult. It uses complex-valued
    embeddings to represent both entities and relations. The complex-valued
    embeddings let ComplEx's scoring function distinguish the two directions
    of asymmetric relations.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``.

    .. _Complex Embeddings for Simple Link Prediction:
        http://proceedings.mlr.press/v48/trouillon16.pdf
    """

    def __init__(self, **kwargs):
        super(Complex, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size

        self.ent_embeddings_real = NamedEmbedding("emb_e_real", num_total_ent, k)
        self.ent_embeddings_img = NamedEmbedding("emb_e_img", num_total_ent, k)
        self.rel_embeddings_real = NamedEmbedding("emb_rel_real", num_total_rel, k)
        self.rel_embeddings_img = NamedEmbedding("emb_rel_img", num_total_rel, k)

        nn.init.xavier_uniform_(self.ent_embeddings_real.weight)
        nn.init.xavier_uniform_(self.ent_embeddings_img.weight)
        nn.init.xavier_uniform_(self.rel_embeddings_real.weight)
        nn.init.xavier_uniform_(self.rel_embeddings_img.weight)

        self.parameter_list = [
            self.ent_embeddings_real,
            self.ent_embeddings_img,
            self.rel_embeddings_real,
            self.rel_embeddings_img,
        ]

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns real and imaginary values of head, relation and tail embedding.
        """
        h_emb_real = self.ent_embeddings_real(h)
        h_emb_img = self.ent_embeddings_img(h)

        r_emb_real = self.rel_embeddings_real(r)
        r_emb_img = self.rel_embeddings_img(r)

        t_emb_real = self.ent_embeddings_real(t)
        t_emb_img = self.ent_embeddings_img(t)

        return h_emb_real, h_emb_img, r_emb_real, r_emb_img, t_emb_real, t_emb_img

    def forward(self, h, r, t):
        """Return the negated real part of the trilinear product <h, r, conj(t)>."""
        h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img = self.embed(h, r, t)
        return -torch.sum(h_e_real * t_e_real * r_e_real + h_e_img * t_e_img * r_e_real
                          + h_e_real * t_e_img * r_e_img - h_e_img * t_e_real * r_e_img, -1)

    def get_reg(self, h, r, t, reg_type="F2"):
        """Regularisation term for a batch of triples.

        Args:
            h, r, t (Tensor): Head, relation and tail ids.
            reg_type (str): ``"F2"`` (sum of squares) or ``"N3"`` (sum of
                absolute cubes); case-insensitive.

        Returns:
            Tensor: ``lmbda`` times the batch mean of the per-triple penalty.

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        h_e_real, h_e_img, r_e_real, r_e_img, t_e_real, t_e_img = self.embed(h, r, t)
        if reg_type.lower() == 'f2':
            regul_term = torch.mean(torch.sum(h_e_real ** 2, -1) + torch.sum(h_e_img ** 2, -1)
                                    + torch.sum(r_e_real ** 2, -1) + torch.sum(r_e_img ** 2, -1)
                                    + torch.sum(t_e_real ** 2, -1) + torch.sum(t_e_img ** 2, -1))
        elif reg_type.lower() == 'n3':
            # |x| ** 3: raw cubes of negative coordinates would lower the penalty.
            regul_term = torch.mean(torch.sum(h_e_real.abs() ** 3, -1) + torch.sum(h_e_img.abs() ** 3, -1)
                                    + torch.sum(r_e_real.abs() ** 3, -1) + torch.sum(r_e_img.abs() ** 3, -1)
                                    + torch.sum(t_e_real.abs() ** 3, -1) + torch.sum(t_e_img.abs() ** 3, -1))
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)

        return self.lmbda * regul_term
class ComplexN3(Complex):
    """ComplEx model whose regulariser defaults to N3.

    Embeddings, ``embed`` and ``forward`` are inherited unchanged from
    :class:`Complex` (`Complex Embeddings for Simple Link Prediction`_). Only
    :meth:`get_reg` is overridden. It defaults to ``reg_type="N3"`` and always
    raises absolute values to the chosen power.

    Args:
        **kwargs: Forwarded verbatim to :class:`Complex`.

    .. _Complex Embeddings for Simple Link Prediction:
        http://proceedings.mlr.press/v48/trouillon16.pdf
    """

    def __init__(self, **kwargs):
        super(ComplexN3, self).__init__(**kwargs)
        self.model_name = 'complexn3'
        self.loss = Criterion.pointwise_logistic

    def get_reg(self, h, r, t, reg_type="N3"):
        """Return ``lmbda * mean_over_batch(sum_i |x_i| ** p)`` over all six parts.

        ``p`` is 2 for ``"F2"`` and 3 for ``"N3"`` (case-insensitive).

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        parts = self.embed(h, r, t)
        kind = reg_type.lower()
        if kind == 'f2':
            exponent = 2
        elif kind == 'n3':
            exponent = 3
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)

        # Sum the per-triple penalties of the six parts, in the order embed() returns them.
        per_triple = sum(torch.sum(part.abs() ** exponent, -1) for part in parts)
        return self.lmbda * torch.mean(per_triple)
class ConvKB(PointwiseModel):
    """`A Novel Embedding Model for Knowledge Base Completion Based on Convolutional Neural Network`_ (ConvKB).

    Each triple (head entity, relation, tail entity) is represented as a
    3 x ``hidden_size`` matrix whose rows are the three embeddings. Several
    convolutions (one per entry of ``filter_sizes``, each spanning all three
    rows) are applied to it. Their flattened outputs feed a single linear
    layer that produces the triple score.

    Portion of the code based on daiquocnguyen_.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size``, ``num_filters`` and ``filter_sizes`` are read
            through ``load_params``; ``device`` is read directly from kwargs.

    .. _daiquocnguyen:
        https://github.com/daiquocnguyen/ConvKB

    .. _A Novel Embedding Model for Knowledge Base Completion Based on Convolutional Neural Network:
        https://www.aclweb.org/anthology/N18-2053
    """

    def __init__(self, **kwargs):
        super(ConvKB, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "num_filters", "filter_sizes"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size
        num_filters = self.num_filters
        filter_sizes = self.filter_sizes
        device = kwargs["device"]

        self.ent_embeddings = NamedEmbedding("ent_embedding", num_total_ent, k)
        self.rel_embeddings = NamedEmbedding("rel_embedding", num_total_rel, k)

        nn.init.xavier_uniform_(self.ent_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)

        self.parameter_list = [
            self.ent_embeddings,
            self.rel_embeddings,
        ]

        # Must be an nn.ModuleList, not a plain list. A plain list hides the
        # convolution layers from nn.Module, so their weights are missing from
        # parameters() (never optimised), from state_dict() (never saved), and
        # are not moved by model.to()/.cuda().
        self.conv_list = nn.ModuleList(
            [nn.Conv2d(1, num_filters, (3, filter_size), stride=(1, 1)).to(device)
             for filter_size in filter_sizes])
        # Each conv yields (k - filter_size + 1) columns per filter.
        conv_out_dim = num_filters * sum(k - filter_size + 1 for filter_size in filter_sizes)
        self.fc1 = nn.Linear(in_features=conv_out_dim, out_features=1, bias=True)

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns head, relation and tail embedding Tensors.
        """
        emb_h = self.ent_embeddings(h)
        emb_r = self.rel_embeddings(r)
        emb_t = self.ent_embeddings(t)
        return emb_h, emb_r, emb_t

    def forward(self, h, r, t):
        """Score triples with the convolution + linear head (one value per triple)."""
        h_emb, r_emb, t_emb = self.embed(h, r, t)
        first_dimen = list(h_emb.shape)[0]

        stacked_h = torch.unsqueeze(h_emb, dim=1)
        stacked_r = torch.unsqueeze(r_emb, dim=1)
        stacked_t = torch.unsqueeze(t_emb, dim=1)

        stacked_hrt = torch.cat([stacked_h, stacked_r, stacked_t], dim=1)
        stacked_hrt = torch.unsqueeze(stacked_hrt, dim=1)  # [b, 1, 3, k]

        # Concatenate conv outputs along the last (width) axis, then flatten per triple.
        stacked_hrt = [conv_layer(stacked_hrt) for conv_layer in self.conv_list]
        stacked_hrt = torch.cat(stacked_hrt, dim=3)
        stacked_hrt = stacked_hrt.view(first_dimen, -1)
        preds = self.fc1(stacked_hrt)
        preds = torch.squeeze(preds, dim=-1)
        return preds
class CP(PointwiseModel):
    """`Canonical Tensor Decomposition for Knowledge Base Completion`_ (CP).

    Entities get separate subject and object embeddings. The triple score is
    the trilinear product of subject, relation and object embeddings.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``.

    .. _Canonical Tensor Decomposition for Knowledge Base Completion:
        http://proceedings.mlr.press/v80/lacroix18a/lacroix18a.pdf
    """

    def __init__(self, **kwargs):
        super(CP, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size

        self.sub_embeddings = NamedEmbedding("sub_embedding", num_total_ent, k)
        self.rel_embeddings = NamedEmbedding("rel_embedding", num_total_rel, k)
        self.obj_embeddings = NamedEmbedding("obj_embedding", num_total_ent, k)

        nn.init.xavier_uniform_(self.sub_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)
        nn.init.xavier_uniform_(self.obj_embeddings.weight)

        self.parameter_list = [
            self.sub_embeddings,
            self.rel_embeddings,
            self.obj_embeddings,
        ]

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns head (subject), relation and tail (object) embedding Tensors.
        """
        emb_h = self.sub_embeddings(h)
        emb_r = self.rel_embeddings(r)
        emb_t = self.obj_embeddings(t)
        return emb_h, emb_r, emb_t

    def forward(self, h, r, t):
        """Return the negated trilinear product of head, relation and tail embeddings."""
        h_e, r_e, t_e = self.embed(h, r, t)
        return -torch.sum(h_e * r_e * t_e, -1)

    def get_reg(self, h, r, t, reg_type='N3'):
        """Regularisation term for a batch of triples.

        Args:
            h, r, t (Tensor): Head, relation and tail ids.
            reg_type (str): ``"F2"`` (sum of squares) or ``"N3"`` (sum of
                absolute cubes, the default); case-insensitive.

        Returns:
            Tensor: ``lmbda`` times the batch mean of the per-triple penalty.

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        h_e, r_e, t_e = self.embed(h, r, t)
        if reg_type.lower() == 'f2':
            regul_term = torch.mean(torch.sum(h_e ** 2, -1) + torch.sum(r_e ** 2, -1) + torch.sum(t_e ** 2, -1))
        elif reg_type.lower() == 'n3':
            # N3 needs |x| ** 3; raw cubes go negative for negative coordinates.
            regul_term = torch.mean(torch.sum(h_e.abs() ** 3, -1) + torch.sum(r_e.abs() ** 3, -1)
                                    + torch.sum(t_e.abs() ** 3, -1))
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)
        return self.lmbda * regul_term
class DistMult(PointwiseModel):
    """`EMBEDDING ENTITIES AND RELATIONS FOR LEARNING AND INFERENCE IN KNOWLEDGE BASES`_ (DistMult).

    DistMult is simpler than RESCAL: it restricts RESCAL's relation weight
    matrix to a diagonal matrix. Its scoring function captures pairwise
    interactions between head and tail entities. Because the score is
    symmetric in head and tail, DistMult cannot model asymmetric relations.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``.

    .. _EMBEDDING ENTITIES AND RELATIONS FOR LEARNING AND INFERENCE IN KNOWLEDGE BASES:
        https://arxiv.org/pdf/1412.6575.pdf
    """

    def __init__(self, **kwargs):
        super(DistMult, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size

        self.ent_embeddings = NamedEmbedding("ent_embedding", num_total_ent, k)
        self.rel_embeddings = NamedEmbedding("rel_embedding", num_total_rel, k)

        nn.init.xavier_uniform_(self.ent_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)

        self.parameter_list = [
            self.ent_embeddings,
            self.rel_embeddings,
        ]

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns head, relation and tail embedding Tensors.
        """
        h_emb = self.ent_embeddings(h)
        r_emb = self.rel_embeddings(r)
        t_emb = self.ent_embeddings(t)
        return h_emb, r_emb, t_emb

    def forward(self, h, r, t):
        """Return the negated trilinear product of head, relation and tail embeddings."""
        h_e, r_e, t_e = self.embed(h, r, t)
        return -torch.sum(h_e * r_e * t_e, -1)

    def get_reg(self, h, r, t, reg_type="F2"):
        """Regularisation term for a batch of triples.

        Args:
            h, r, t (Tensor): Head, relation and tail ids.
            reg_type (str): ``"F2"`` (sum of squares) or ``"N3"`` (sum of
                absolute cubes); case-insensitive.

        Returns:
            Tensor: ``lmbda`` times the batch mean of the per-triple penalty.

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        h_e, r_e, t_e = self.embed(h, r, t)
        if reg_type.lower() == 'f2':
            regul_term = torch.mean(torch.sum(h_e ** 2, -1) + torch.sum(r_e ** 2, -1) + torch.sum(t_e ** 2, -1))
        elif reg_type.lower() == 'n3':
            # N3 needs |x| ** 3; raw cubes go negative for negative coordinates.
            regul_term = torch.mean(torch.sum(h_e.abs() ** 3, -1) + torch.sum(r_e.abs() ** 3, -1)
                                    + torch.sum(t_e.abs() ** 3, -1))
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)
        return self.lmbda * regul_term
class SimplE(PointwiseModel):
    """`SimplE Embedding for Link Prediction in Knowledge Graphs`_ (SimplE).

    Every entity has a "head-role" and a "tail-role" embedding. Every
    relation has a forward and an inverse embedding. The triple score is the
    average of the forward term <head(h), rel(r), tail(t)> and the inverse
    term <head(t), rel_inv(r), tail(h)>.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``;
            ``tot_train_triples`` and ``batch_size`` are read directly.

    .. _SimplE Embedding for Link Prediction in Knowledge Graphs:
        https://papers.nips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
    """

    def __init__(self, **kwargs):
        super(SimplE, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size
        self.tot_train_triples = kwargs['tot_train_triples']
        self.batch_size = kwargs['batch_size']

        self.ent_head_embeddings = NamedEmbedding("ent_head_embedding", num_total_ent, k)
        self.ent_tail_embeddings = NamedEmbedding("ent_tail_embedding", num_total_ent, k)
        self.rel_embeddings = NamedEmbedding("rel_embedding", num_total_rel, k)
        self.rel_inv_embeddings = NamedEmbedding("rel_inv_embedding", num_total_rel, k)

        nn.init.xavier_uniform_(self.ent_head_embeddings.weight)
        nn.init.xavier_uniform_(self.ent_tail_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_embeddings.weight)
        nn.init.xavier_uniform_(self.rel_inv_embeddings.weight)

        self.parameter_list = [
            self.ent_head_embeddings,
            self.ent_tail_embeddings,
            self.rel_embeddings,
            self.rel_inv_embeddings,
        ]

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: ``(head(h), head(t), rel(r), rel_inv(r), tail(t), tail(h))``,
            i.e. the embeddings needed by the forward and the inverse term.
        """
        emb_h1 = self.ent_head_embeddings(h)
        emb_h2 = self.ent_head_embeddings(t)
        emb_r1 = self.rel_embeddings(r)
        emb_r2 = self.rel_inv_embeddings(r)
        emb_t1 = self.ent_tail_embeddings(t)
        emb_t2 = self.ent_tail_embeddings(h)
        return emb_h1, emb_h2, emb_r1, emb_r2, emb_t1, emb_t2

    def forward(self, h, r, t):
        """Return the negated, clamped (to [-20, 20]) average of both SimplE terms."""
        h1_e, h2_e, r1_e, r2_e, t1_e, t2_e = self.embed(h, r, t)
        # Parenthesised so that BOTH terms are averaged; previously operator
        # precedence divided only the inverse term by 2.
        init = (torch.sum(h1_e * r1_e * t1_e, 1) + torch.sum(h2_e * r2_e * t2_e, 1)) / 2.0
        return -torch.clamp(init, -20, 20)

    def get_reg(self, h, r, t, reg_type="F2"):
        """Regularisation term computed on the embeddings of a batch of triples.

        The penalty is taken over every tensor returned by ``self.embed``, so
        subclasses that override ``embed`` with a different tuple length are
        handled too.

        Args:
            h, r, t (Tensor): Head, relation and tail ids.
            reg_type (str): ``"F2"`` (sum of squares) or ``"N3"`` (sum of
                absolute cubes); case-insensitive.

        Returns:
            Tensor: ``lmbda`` times the batch mean of the per-triple penalty.

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        # Regularise the embeddings. The previous code squared/cubed the raw
        # integer ids, which carry no gradient and were moved to CPU.
        embeddings = self.embed(h, r, t)
        if reg_type.lower() == 'f2':
            regul_term = torch.mean(sum(torch.sum(emb ** 2, -1) for emb in embeddings))
        elif reg_type.lower() == 'n3':
            regul_term = torch.mean(sum(torch.sum(emb.abs() ** 3, -1) for emb in embeddings))
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)
        return self.lmbda * regul_term
class SimplE_ignr(SimplE):
    """SimplE variant that scores triples on concatenated embeddings.

    Uses the embedding tables created by :class:`SimplE`
    (`SimplE Embedding for Link Prediction in Knowledge Graphs`_). The
    forward-direction and inverse-direction embeddings are concatenated along
    axis 1, so the score is one trilinear product over the concatenation.
    That equals the sum (not the average) of the two SimplE terms. The score
    is clamped to [-20, 20] and negated.

    Args:
        **kwargs: Forwarded verbatim to :class:`SimplE`.

    .. _SimplE Embedding for Link Prediction in Knowledge Graphs:
        https://papers.nips.cc/paper/7682-simple-embedding-for-link-prediction-in-knowledge-graphs.pdf
    """

    def __init__(self, **kwargs):
        super(SimplE_ignr, self).__init__(**kwargs)
        self.model_name = 'simple_ignr'
        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Return concatenated head, relation and tail embeddings.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: ``[head(h) | head(t)]``, ``[rel(r) | rel_inv(r)]`` and
            ``[tail(t) | tail(h)]``, each concatenated along axis 1.
        """
        concat = self._concat_selected_embeddings
        head_part = concat(self.ent_head_embeddings, h, self.ent_head_embeddings, t)
        relation_part = concat(self.rel_embeddings, r, self.rel_inv_embeddings, r)
        tail_part = concat(self.ent_tail_embeddings, t, self.ent_tail_embeddings, h)
        return head_part, relation_part, tail_part

    def forward(self, h, r, t):
        """Return the negated, clamped trilinear product of the concatenated embeddings."""
        head_part, relation_part, tail_part = self.embed(h, r, t)
        raw_score = (head_part * relation_part * tail_part).sum(1)
        return -raw_score.clamp(-20, 20)

    @staticmethod
    def _concat_selected_embeddings(e1, t1, e2, t2):
        """Gather rows ``t1`` of ``e1.weight`` and ``t2`` of ``e2.weight`` and join them on axis 1."""
        first = e1.weight.index_select(0, t1)
        second = e2.weight.index_select(0, t2)
        return torch.cat((first, second), 1)
class QuatE(PointwiseModel):
    """`Quaternion Knowledge Graph Embeddings`_ (QuatE).

    Entities and relations are quaternions whose four components (s, x, y, z)
    each have size ``hidden_size``. A triple is scored as follows:
      1. Normalise the relation quaternion to unit length.
      2. Take the Hamilton product of the head with it.
      3. Take the inner product of the result with the tail.

    Reference implementation: cheungdaven_.

    Args:
        **kwargs: Model configuration. ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` are read through ``load_params``.

    .. _cheungdaven: https://github.com/cheungdaven/QuatE.git

    .. _Quaternion Knowledge Graph Embeddings: https://arxiv.org/abs/1904.10281
    """

    def __init__(self, **kwargs):
        super(QuatE, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        num_total_ent = self.tot_entity
        num_total_rel = self.tot_relation
        k = self.hidden_size

        # Entity quaternion components.
        self.ent_s_embedding = NamedEmbedding("ent_s_embedding", num_total_ent, k)
        self.ent_x_embedding = NamedEmbedding("ent_x_embedding", num_total_ent, k)
        self.ent_y_embedding = NamedEmbedding("ent_y_embedding", num_total_ent, k)
        self.ent_z_embedding = NamedEmbedding("ent_z_embedding", num_total_ent, k)
        # Relation quaternion components, plus rel_w (not used by forward/get_reg here).
        self.rel_s_embedding = NamedEmbedding("rel_s_embedding", num_total_rel, k)
        self.rel_x_embedding = NamedEmbedding("rel_x_embedding", num_total_rel, k)
        self.rel_y_embedding = NamedEmbedding("rel_y_embedding", num_total_rel, k)
        self.rel_z_embedding = NamedEmbedding("rel_z_embedding", num_total_rel, k)
        self.rel_w_embedding = NamedEmbedding("rel_w_embedding", num_total_rel, k)

        # NOTE(review): fc, ent_dropout, rel_dropout and bn are not referenced by
        # any method of this class; kept so the registered sub-modules are unchanged.
        self.fc = nn.Linear(100, 50, bias=False)
        self.ent_dropout = nn.Dropout(0)
        self.rel_dropout = nn.Dropout(0)
        self.bn = nn.BatchNorm1d(k)

        self.parameter_list = [
            self.ent_s_embedding,
            self.ent_x_embedding,
            self.ent_y_embedding,
            self.ent_z_embedding,
            self.rel_s_embedding,
            self.rel_x_embedding,
            self.rel_y_embedding,
            self.rel_z_embedding,
            self.rel_w_embedding,
        ]

        # Xavier-uniform init for every table, in the same order as before.
        # The old code first copied _quaternion_init() values into the tables
        # and then overwrote them with xavier_uniform_, so those values never
        # took effect. It also built the relation tables from tot_entity rows
        # instead of tot_relation. That dead, mis-shaped step has been
        # removed; the Xavier init that was actually in effect is kept.
        for embedding in self.parameter_list:
            nn.init.xavier_uniform_(embedding.weight.data)

        self.loss = Criterion.pointwise_logistic

    def embed(self, h, r, t):
        """Return the (s, x, y, z) components of head, tail and relation, in that order."""
        s_emb_h = self.ent_s_embedding(h)
        x_emb_h = self.ent_x_embedding(h)
        y_emb_h = self.ent_y_embedding(h)
        z_emb_h = self.ent_z_embedding(h)

        s_emb_t = self.ent_s_embedding(t)
        x_emb_t = self.ent_x_embedding(t)
        y_emb_t = self.ent_y_embedding(t)
        z_emb_t = self.ent_z_embedding(t)

        s_emb_r = self.rel_s_embedding(r)
        x_emb_r = self.rel_x_embedding(r)
        y_emb_r = self.rel_y_embedding(r)
        z_emb_r = self.rel_z_embedding(r)

        return s_emb_h, x_emb_h, y_emb_h, z_emb_h, s_emb_t, x_emb_t, y_emb_t, z_emb_t, s_emb_r, x_emb_r, y_emb_r, z_emb_r

    def forward(self, h, r, t):
        """Return the negated inner product of (head * unit relation) with the tail."""
        (s_emb_h, x_emb_h, y_emb_h, z_emb_h,
         s_emb_t, x_emb_t, y_emb_t, z_emb_t,
         s_emb_r, x_emb_r, y_emb_r, z_emb_r) = self.embed(h, r, t)

        # Normalise the relation quaternion to unit length (element-wise over the hidden axis).
        denominator_b = torch.sqrt(s_emb_r ** 2 + x_emb_r ** 2 + y_emb_r ** 2 + z_emb_r ** 2)
        s_emb_r = s_emb_r / denominator_b
        x_emb_r = x_emb_r / denominator_b
        y_emb_r = y_emb_r / denominator_b
        z_emb_r = z_emb_r / denominator_b

        # Hamilton product head * relation.
        a = s_emb_h * s_emb_r - x_emb_h * x_emb_r - y_emb_h * y_emb_r - z_emb_h * z_emb_r
        b = s_emb_h * x_emb_r + s_emb_r * x_emb_h + y_emb_h * z_emb_r - y_emb_r * z_emb_h
        c = s_emb_h * y_emb_r + s_emb_r * y_emb_h + z_emb_h * x_emb_r - z_emb_r * x_emb_h
        d = s_emb_h * z_emb_r + s_emb_r * z_emb_h + x_emb_h * y_emb_r - x_emb_r * y_emb_h

        score_r = (a * s_emb_t + b * x_emb_t + c * y_emb_t + d * z_emb_t)
        return -torch.sum(score_r, -1)

    def get_reg(self, h, r, t, reg_type='N3'):
        """Regularisation: sum of per-component means of ``|x| ** p`` (p = 2 for F2, 3 for N3).

        Raises:
            NotImplementedError: If ``reg_type`` is neither F2 nor N3.
        """
        (s_emb_h, x_emb_h, y_emb_h, z_emb_h,
         s_emb_t, x_emb_t, y_emb_t, z_emb_t,
         s_emb_r, x_emb_r, y_emb_r, z_emb_r) = self.embed(h, r, t)
        if reg_type.lower() == 'f2':
            regul = (torch.mean(torch.abs(s_emb_h) ** 2)
                     + torch.mean(torch.abs(x_emb_h) ** 2)
                     + torch.mean(torch.abs(y_emb_h) ** 2)
                     + torch.mean(torch.abs(z_emb_h) ** 2)
                     + torch.mean(torch.abs(s_emb_t) ** 2)
                     + torch.mean(torch.abs(x_emb_t) ** 2)
                     + torch.mean(torch.abs(y_emb_t) ** 2)
                     + torch.mean(torch.abs(z_emb_t) ** 2))
            regul2 = (torch.mean(torch.abs(s_emb_r) ** 2)
                      + torch.mean(torch.abs(x_emb_r) ** 2)
                      + torch.mean(torch.abs(y_emb_r) ** 2)
                      + torch.mean(torch.abs(z_emb_r) ** 2))
        elif reg_type.lower() == 'n3':
            regul = (torch.mean(torch.abs(s_emb_h) ** 3)
                     + torch.mean(torch.abs(x_emb_h) ** 3)
                     + torch.mean(torch.abs(y_emb_h) ** 3)
                     + torch.mean(torch.abs(z_emb_h) ** 3)
                     + torch.mean(torch.abs(s_emb_t) ** 3)
                     + torch.mean(torch.abs(x_emb_t) ** 3)
                     + torch.mean(torch.abs(y_emb_t) ** 3)
                     + torch.mean(torch.abs(z_emb_t) ** 3))
            regul2 = (torch.mean(torch.abs(s_emb_r) ** 3)
                      + torch.mean(torch.abs(x_emb_r) ** 3)
                      + torch.mean(torch.abs(y_emb_r) ** 3)
                      + torch.mean(torch.abs(z_emb_r) ** 3))
        else:
            raise NotImplementedError('Unknown regularizer type: %s' % reg_type)

        return self.lmbda * (regul + regul2)

    @staticmethod
    def _quaternion_init(in_features, out_features, criterion='he'):
        """Quaternion weight initialisation.

        Returns four numpy arrays of shape ``(in_features, out_features)``:
        ``modulus * cos(phase)`` and ``modulus * v * sin(phase)`` for the
        three (almost) unit-normalised imaginary directions ``v``.

        Note: the directions are drawn from the global ``np.random`` state,
        while modulus and phase come from a ``RandomState(123)``.

        Raises:
            ValueError: If ``criterion`` is not ``'glorot'`` or ``'he'``.
        """
        fan_in = in_features
        fan_out = out_features

        if criterion == 'glorot':
            s = 1. / np.sqrt(2 * (fan_in + fan_out))
        elif criterion == 'he':
            s = 1. / np.sqrt(2 * fan_in)
        else:
            raise ValueError('Invalid criterion: %s' % criterion)
        rng = RandomState(123)

        kernel_shape = (in_features, out_features)
        number_of_weights = np.prod(kernel_shape)
        v_i = np.random.uniform(0.0, 1.0, number_of_weights)
        v_j = np.random.uniform(0.0, 1.0, number_of_weights)
        v_k = np.random.uniform(0.0, 1.0, number_of_weights)

        # Normalise each (i, j, k) triple. The 1e-4 epsilon avoids division by
        # zero. This is the vectorised form of the former per-element Python loop.
        norm = np.sqrt(v_i ** 2 + v_j ** 2 + v_k ** 2) + 0.0001
        v_i = (v_i / norm).reshape(kernel_shape)
        v_j = (v_j / norm).reshape(kernel_shape)
        v_k = (v_k / norm).reshape(kernel_shape)

        modulus = rng.uniform(low=-s, high=s, size=kernel_shape)
        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

        weight_r = modulus * np.cos(phase)
        weight_i = modulus * v_i * np.sin(phase)
        weight_j = modulus * v_j * np.sin(phase)
        weight_k = modulus * v_k * np.sin(phase)

        return weight_r, weight_i, weight_j, weight_k
[docs]class OctonionE(PointwiseModel): """ `Quaternion Knowledge Graph Embeddings`_ Args: config (object): Model configuration parameters. .. _cheungdaven: https://github.com/cheungdaven/QuatE.git .. _Quaternion Knowledge Graph Embeddings: https://arxiv.org/abs/1904.10281 """ def __init__(self, **kwargs): super(OctonionE, self).__init__(self.__class__.__name__.lower()) param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"] param_dict = self.load_params(param_list, kwargs) self.__dict__.update(param_dict) num_total_ent = self.tot_entity num_total_rel = self.tot_relation k = self.hidden_size self.ent_embedding_1 = NamedEmbedding("ent_embedding_1", num_total_ent, k) self.ent_embedding_2 = NamedEmbedding("ent_embedding_2", num_total_ent, k) self.ent_embedding_3 = NamedEmbedding("ent_embedding_3", num_total_ent, k) self.ent_embedding_4 = NamedEmbedding("ent_embedding_4", num_total_ent, k) self.ent_embedding_5 = NamedEmbedding("ent_embedding_5", num_total_ent, k) self.ent_embedding_6 = NamedEmbedding("ent_embedding_6", num_total_ent, k) self.ent_embedding_7 = NamedEmbedding("ent_embedding_7", num_total_ent, k) self.ent_embedding_8 = NamedEmbedding("ent_embedding_8", num_total_ent, k) self.rel_embedding_1 = NamedEmbedding("rel_embedding_1", num_total_rel, k) self.rel_embedding_2 = NamedEmbedding("rel_embedding_2", num_total_rel, k) self.rel_embedding_3 = NamedEmbedding("rel_embedding_3", num_total_rel, k) self.rel_embedding_4 = NamedEmbedding("rel_embedding_4", num_total_rel, k) self.rel_embedding_5 = NamedEmbedding("rel_embedding_5", num_total_rel, k) self.rel_embedding_6 = NamedEmbedding("rel_embedding_6", num_total_rel, k) self.rel_embedding_7 = NamedEmbedding("rel_embedding_7", num_total_rel, k) self.rel_embedding_8 = NamedEmbedding("rel_embedding_8", num_total_rel, k) self.rel_w_embedding = NamedEmbedding("rel_w_embedding", num_total_rel, k) nn.init.xavier_uniform_(self.ent_embedding_1.weight.data) 
nn.init.xavier_uniform_(self.ent_embedding_2.weight.data) nn.init.xavier_uniform_(self.ent_embedding_3.weight.data) nn.init.xavier_uniform_(self.ent_embedding_4.weight.data) nn.init.xavier_uniform_(self.ent_embedding_5.weight.data) nn.init.xavier_uniform_(self.ent_embedding_6.weight.data) nn.init.xavier_uniform_(self.ent_embedding_7.weight.data) nn.init.xavier_uniform_(self.ent_embedding_8.weight.data) nn.init.xavier_uniform_(self.rel_embedding_1.weight.data) nn.init.xavier_uniform_(self.rel_embedding_2.weight.data) nn.init.xavier_uniform_(self.rel_embedding_3.weight.data) nn.init.xavier_uniform_(self.rel_embedding_4.weight.data) nn.init.xavier_uniform_(self.rel_embedding_5.weight.data) nn.init.xavier_uniform_(self.rel_embedding_6.weight.data) nn.init.xavier_uniform_(self.rel_embedding_7.weight.data) nn.init.xavier_uniform_(self.rel_embedding_8.weight.data) nn.init.xavier_uniform_(self.rel_w_embedding.weight.data) self.parameter_list = [ self.ent_embedding_1, self.ent_embedding_2, self.ent_embedding_3, self.ent_embedding_4, self.ent_embedding_5, self.ent_embedding_6, self.ent_embedding_7, self.ent_embedding_8, self.rel_embedding_1, self.rel_embedding_2, self.rel_embedding_3, self.rel_embedding_4, self.rel_embedding_5, self.rel_embedding_6, self.rel_embedding_7, self.rel_embedding_8, self.rel_w_embedding, ] self.loss = Criterion.pointwise_logistic
[docs] def embed(self, h, r, t): e_1_h = self.ent_embedding_1(h) e_2_h = self.ent_embedding_2(h) e_3_h = self.ent_embedding_3(h) e_4_h = self.ent_embedding_4(h) e_5_h = self.ent_embedding_5(h) e_6_h = self.ent_embedding_6(h) e_7_h = self.ent_embedding_7(h) e_8_h = self.ent_embedding_8(h) e_1_t = self.ent_embedding_1(t) e_2_t = self.ent_embedding_2(t) e_3_t = self.ent_embedding_3(t) e_4_t = self.ent_embedding_4(t) e_5_t = self.ent_embedding_5(t) e_6_t = self.ent_embedding_6(t) e_7_t = self.ent_embedding_7(t) e_8_t = self.ent_embedding_8(t) r_1 = self.rel_embedding_1(r) r_2 = self.rel_embedding_2(r) r_3 = self.rel_embedding_3(r) r_4 = self.rel_embedding_4(r) r_5 = self.rel_embedding_5(r) r_6 = self.rel_embedding_6(r) r_7 = self.rel_embedding_7(r) r_8 = self.rel_embedding_8(r) return e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, \ e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, \ r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8
[docs] def forward(self, h, r, t): e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, \ e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, \ r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self.embed(h, r, t) r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = OctonionE._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = OctonionE._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t + o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t) return -torch.sum(score_r, -1)
[docs] def get_reg(self, h, r, t, reg_type='N3'): e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, \ e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, \ r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self.embed(h, r, t) if reg_type.lower() == 'f2': regul = (torch.mean(torch.abs(e_1_h) ** 2) + torch.mean(torch.abs(e_2_h) ** 2) + torch.mean(torch.abs(e_3_h) ** 2) + torch.mean(torch.abs(e_4_h) ** 2) + torch.mean(torch.abs(e_5_h) ** 2) + torch.mean(torch.abs(e_6_h) ** 2) + torch.mean(torch.abs(e_7_h) ** 2) + torch.mean(torch.abs(e_8_h) ** 2) + torch.mean(torch.abs(e_1_t) ** 2) + torch.mean(torch.abs(e_2_t) ** 2) + torch.mean(torch.abs(e_3_t) ** 2) + torch.mean(torch.abs(e_4_t) ** 2) + torch.mean(torch.abs(e_5_t) ** 2) + torch.mean(torch.abs(e_6_t) ** 2) + torch.mean(torch.abs(e_7_t) ** 2) + torch.mean(torch.abs(e_8_t) ** 2) ) regul2 = (torch.mean(torch.abs(r_1) ** 2) + torch.mean(torch.abs(r_2) ** 2) + torch.mean(torch.abs(r_3) ** 2) + torch.mean(torch.abs(r_4) ** 2) + torch.mean(torch.abs(r_5) ** 2) + torch.mean(torch.abs(r_6) ** 2) + torch.mean(torch.abs(r_7) ** 2) + torch.mean(torch.abs(r_8) ** 2)) elif reg_type.lower() == 'n3': regul = (torch.mean(torch.abs(e_1_h) ** 3) + torch.mean(torch.abs(e_2_h) ** 3) + torch.mean(torch.abs(e_3_h) ** 3) + torch.mean(torch.abs(e_4_h) ** 3) + torch.mean(torch.abs(e_5_h) ** 3) + torch.mean(torch.abs(e_6_h) ** 3) + torch.mean(torch.abs(e_7_h) ** 3) + torch.mean(torch.abs(e_8_h) ** 3) + torch.mean(torch.abs(e_1_t) ** 3) + torch.mean(torch.abs(e_2_t) ** 3) + torch.mean(torch.abs(e_3_t) ** 3) + torch.mean(torch.abs(e_4_t) ** 3) + torch.mean(torch.abs(e_5_t) ** 3) + torch.mean(torch.abs(e_6_t) ** 3) + torch.mean(torch.abs(e_7_t) ** 3) + torch.mean(torch.abs(e_8_t) ** 3) ) regul2 = (torch.mean(torch.abs(r_1) ** 3) + torch.mean(torch.abs(r_2) ** 3) + torch.mean(torch.abs(r_3) ** 3) + torch.mean(torch.abs(r_4) ** 3) + torch.mean(torch.abs(r_5) ** 3) + torch.mean(torch.abs(r_6) ** 3) + torch.mean(torch.abs(r_7) ** 3) + 
torch.mean(torch.abs(r_8) ** 3)) else: raise NotImplementedError('Unknown regularizer type: %s' % reg_type) return self.lmbda * (regul + regul2)
@staticmethod
def _qmult(s_a, x_a, y_a, z_a, s_b, x_b, y_b, z_b):
    """Hamilton product of quaternions (s_a, x_a, y_a, z_a) * (s_b, x_b, y_b, z_b).

    Only ``*``, ``+`` and ``-`` are used, so the arguments may be scalars or
    tensors. Returns the (real, i, j, k) components of the product.
    """
    return (
        s_a * s_b - x_a * x_b - y_a * y_b - z_a * z_b,
        s_a * x_b + s_b * x_a + y_a * z_b - y_b * z_a,
        s_a * y_b + s_b * y_a + z_a * x_b - z_b * x_a,
        s_a * z_b + s_b * z_a + x_a * y_b - x_b * y_a,
    )


@staticmethod
def _qstar(a, b, c, d):
    """Quaternion conjugate: keep the real part, negate the three imaginary parts."""
    return (a,) + tuple(-part for part in (b, c, d))


@staticmethod
def _omult(a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4, c_1, c_2, c_3, c_4, d_1, d_2, d_3, d_4):
    """Multiply two octonions, each given as a pair of quaternions: (a, b) * (c, d).

    Returns eight components: the four of ``a*c - conj(d)*b`` followed by the
    four of ``d*a + b*conj(c)``, where each product is ``OctonionE._qmult`` and
    ``conj`` is ``OctonionE._qstar``.
    """
    d_conj = OctonionE._qstar(d_1, d_2, d_3, d_4)
    c_conj = OctonionE._qstar(c_1, c_2, c_3, c_4)
    a_c = OctonionE._qmult(a_1, a_2, a_3, a_4, c_1, c_2, c_3, c_4)
    dconj_b = OctonionE._qmult(*d_conj, b_1, b_2, b_3, b_4)
    d_a = OctonionE._qmult(d_1, d_2, d_3, d_4, a_1, a_2, a_3, a_4)
    b_cconj = OctonionE._qmult(b_1, b_2, b_3, b_4, *c_conj)
    first_half = tuple(lhs - rhs for lhs, rhs in zip(a_c, dconj_b))
    second_half = tuple(lhs + rhs for lhs, rhs in zip(d_a, b_cconj))
    return first_half + second_half


@staticmethod
def _onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8):
    """Scale eight octonion components to unit norm, elementwise.

    Every component is divided by ``sqrt(r_1**2 + ... + r_8**2)``. No
    epsilon is added, so positions where all eight components are zero
    come out as NaN.
    """
    parts = (r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8)
    magnitude = torch.sqrt(sum(part ** 2 for part in parts))
    return tuple(part / magnitude for part in parts)
class MuRP(PointwiseModel):
    """
    `Multi-relational Poincaré Graph Embeddings`_

    A triple (h, r, t) is scored as
    ``-d(exp_0(wu[r] * log_0(h)), p_sum(t, r)) ** 2 + bs[h] + bo[t]``, where
    ``d`` is the distance on the unit Poincaré ball, ``p_sum`` is Mobius
    addition, and ``exp_0`` / ``log_0`` are the exponential / logarithmic
    maps at the origin of that ball.

    Args:
        **kwargs: Must contain ``tot_entity``, ``tot_relation``,
            ``hidden_size`` and ``lmbda`` (read via ``load_params``) plus
            ``device``, the torch device every parameter is created on.

    .. _Multi-relational Poincaré Graph Embeddings:
        https://arxiv.org/abs/1905.09791
    """

    def __init__(self, **kwargs):
        super(MuRP, self).__init__(self.__class__.__name__.lower())
        param_list = ["tot_entity", "tot_relation", "hidden_size", "lmbda"]
        param_dict = self.load_params(param_list, kwargs)
        self.__dict__.update(param_dict)

        k = self.hidden_size
        self.device = kwargs["device"]

        # Ids are 0-based (the predict_* methods score every id in
        # range(tot_entity) / range(tot_relation)), so no row is padding.
        # A padding_idx=0 would zero the gradient of entity 0 and relation 0
        # and leave them frozen at their random initialisation.
        self.ent_embeddings = NamedEmbedding("ent_embedding", self.tot_entity, k)
        # Small double-precision init keeps the initial points close to the origin.
        self.ent_embeddings.weight.data = (
            1e-3 * torch.randn((self.tot_entity, k), dtype=torch.double, device=self.device))
        self.rel_embeddings = NamedEmbedding("rel_embedding", self.tot_relation, k)
        self.rel_embeddings.weight.data = (
            1e-3 * torch.randn((self.tot_relation, k), dtype=torch.double, device=self.device))
        # Per-relation diagonal scaling applied to the head in tangent space.
        self.wu = nn.Parameter(
            torch.tensor(np.random.uniform(-1, 1, (self.tot_relation, k)),
                         dtype=torch.double, requires_grad=True, device=self.device))
        # Per-entity scalar biases for the head (bs) and tail (bo) positions.
        self.bs = nn.Parameter(
            torch.zeros(self.tot_entity, dtype=torch.double, requires_grad=True, device=self.device))
        self.bo = nn.Parameter(
            torch.zeros(self.tot_entity, dtype=torch.double, requires_grad=True, device=self.device))

        self.parameter_list = [
            self.ent_embeddings,
            self.rel_embeddings,
        ]

        self.loss = Criterion.pointwise_bce

    def embed(self, h, r, t):
        """Function to get the embedding value.

        Args:
            h (Tensor): Head entities ids.
            r (Tensor): Relation ids of the triple.
            t (Tensor): Tail entity ids of the triple.

        Returns:
            Tensors: Returns head, relation and tail embedding Tensors.
        """
        h_emb = self.ent_embeddings(h)
        r_emb = self.rel_embeddings(r)
        t_emb = self.ent_embeddings(t)
        return h_emb, r_emb, t_emb

    def forward(self, h, r, t):
        """Score triples with the Poincaré-ball model.

        Args:
            h (Tensor): Head entity ids.
            r (Tensor): Relation ids.
            t (Tensor): Tail entity ids.

        Returns:
            Tensor: ``-(squared distance) + bs[h] + bo[t]``. Larger values mean
            the transformed head and the translated tail are closer.
        """
        return self._poincare_forward(h, r, t)

    def predict_tail_rank(self, h, r, topk):
        """Order every entity id by its ``forward`` score as the tail of (h, r, ?).

        Args:
            h (Tensor): Head entity id(s); combined elementwise with the full
                candidate list, so presumably a single id.
            r (Tensor): Relation id(s), combined the same way.
            topk: Ignored; the complete ranking is always returned.

        Returns:
            Tensor: Entity ids sorted by ascending ``forward`` score.
        """
        # NOTE(review): forward() grows as the pair gets closer, so this
        # ascending sort lists the lowest-scoring candidates first. Confirm
        # the ordering the evaluator expects before changing the direction.
        del topk
        candidates = torch.arange(self.tot_entity, device=self.device)
        _, rank = torch.sort(self.forward(h, r, candidates))
        return rank

    def predict_head_rank(self, t, r, topk):
        """Order every entity id by its ``forward`` score as the head of (?, r, t).

        ``topk`` is ignored; see ``predict_tail_rank`` for the ordering.
        """
        del topk
        candidates = torch.arange(self.tot_entity, device=self.device)
        _, rank = torch.sort(self.forward(candidates, r, t))
        return rank

    def predict_rel_rank(self, h, t, topk):
        """Order every relation id by its ``forward`` score for (h, ?, t).

        ``topk`` is ignored; see ``predict_tail_rank`` for the ordering.
        """
        del topk
        candidates = torch.arange(self.tot_relation, device=self.device)
        _, rank = torch.sort(self.forward(h, candidates, t))
        return rank

    def _poincare_forward(self, h, r, t):
        """Compute the MuRP score on the unit Poincaré ball (formula in the class docstring)."""
        h_emb, r_emb, t_emb = self.embed(h, r, t)
        ru = self.wu[r]

        # Keep every point strictly inside the unit ball before using it.
        h_emb = self._project(h_emb)
        t_emb = self._project(t_emb)
        r_emb = self._project(r_emb)

        # Head: move to the tangent space at the origin, scale by the
        # relation diagonal, and map back onto the ball.
        u_m = self._p_exp_map(self._p_log_map(h_emb) * ru)
        # Tail: translate by the relation vector with Mobius addition.
        v_m = self._p_sum(t_emb, r_emb)

        u_m = self._project(u_m)
        v_m = self._project(v_m)

        # Poincaré distance d(u, v) = 2 * artanh(||(-u) (+) v||); the clamp
        # keeps artanh finite.
        dist_arg = torch.clamp(torch.norm(self._p_sum(-u_m, v_m), 2, dim=-1), 1e-10, 1 - 1e-5)
        sqdist = (2. * self._artanh(dist_arg)) ** 2
        return -(sqdist - self.bs[h] - self.bo[t])

    def _euclidean_forward(self, h, r, t):
        """Euclidean counterpart of ``_poincare_forward`` (not called by ``forward``)."""
        h_emb, r_emb, t_emb = self.embed(h, r, t)
        ru = self.wu[r]
        u_w = h_emb * ru
        sqdist = torch.sum(torch.pow(u_w - (t_emb + r_emb), 2), dim=-1)
        return -(sqdist - self.bs[h] - self.bo[t])

    @staticmethod
    def _project(x, eps=1e-5):
        """Pull vectors (last dim) with L2 norm >= 1 strictly inside the unit ball.

        Such vectors are divided by ``norm + eps``, which gives a new norm of
        ``norm / (norm + eps) < 1``. Dividing by ``norm - eps`` instead would
        leave the norm slightly above 1. Vectors already inside are returned
        unchanged.
        """
        norm = torch.norm(x, 2, dim=-1, keepdim=True)
        return torch.where(norm >= 1, x / (norm + eps), x)

    @staticmethod
    def _arsech(x):
        """Inverse hyperbolic secant. Not used by the Poincaré maps; kept for callers."""
        return torch.log((1 + torch.sqrt(1 - x.pow(2))) / x)

    @staticmethod
    def _artanh(x):
        """Inverse hyperbolic tangent, ``0.5 * log((1 + x) / (1 - x))`` for ``|x| < 1``."""
        return 0.5 * torch.log((1 + x) / (1 - x))

    @staticmethod
    def _p_exp_map(v):
        """Exponential map at the origin of the unit ball: ``tanh(||v||) * v / ||v||``.

        The norm is clamped from below, so a zero vector maps to the origin
        rather than NaN. A ``sech`` factor here would send near-zero vectors
        to the boundary instead of the origin.
        """
        normv = torch.clamp(torch.norm(v, 2, dim=-1, keepdim=True), min=1e-10)
        return torch.tanh(normv) * v / normv

    @staticmethod
    def _p_log_map(v):
        """Logarithmic map at the origin: ``artanh(||v||) * v / ||v||``.

        This is the inverse of ``_p_exp_map``. The norm is clamped to
        [1e-10, 1 - 1e-5] so that artanh stays finite.
        """
        normv = torch.clamp(torch.norm(v, 2, dim=-1, keepdim=True), 1e-10, 1 - 1e-5)
        return MuRP._artanh(normv) * v / normv

    @staticmethod
    def _p_sum(x, y):
        """Mobius addition ``x (+) y`` on the unit ball, applied along the last dim.

        Squared norms are clamped to [0, 1 - 1e-5] to keep the result well defined.
        """
        sqxnorm = torch.clamp(torch.sum(x * x, dim=-1, keepdim=True), 0, 1 - 1e-5)
        sqynorm = torch.clamp(torch.sum(y * y, dim=-1, keepdim=True), 0, 1 - 1e-5)
        dotxy = torch.sum(x * y, dim=-1, keepdim=True)
        numerator = (1 + 2 * dotxy + sqynorm) * x + (1 - sqxnorm) * y
        denominator = 1 + 2 * dotxy + sqxnorm * sqynorm
        return numerator / denominator