Source code for pykg2vec.test.test_kg

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module is for testing unit functions of KnowledgeGraph
"""

import os
import pytest
from pathlib import Path
from pykg2vec.data.kgcontroller import KnowledgeGraph
from pykg2vec.data.datasets import KnownDataset


@pytest.mark.parametrize("dataset_name", [
    "freebase15k",
    "deeplearning50a",
    "wordnet18",
    "wordnet18_rr",
    "yago3_10",
    "freebase15k_237",
    "kinship",
    "nations",
    "umls",
    "nell_995",
])
def test_known_datasets(dataset_name):
    """Check that a known dataset can be prepared, cached and read back.

    Runs once per parametrized dataset name. The knowledge graph is forced
    to prepare its data, after which the dump, the cache flag, the dataset
    metadata counters and every cached structure are checked.

    Args:
        dataset_name: name of the dataset handed to ``KnowledgeGraph``.
    """
    knowledge_graph = KnowledgeGraph(dataset=dataset_name)
    knowledge_graph.force_prepare_data()

    assert len(knowledge_graph.dump()) == 9
    assert knowledge_graph.is_cache_exists()

    # Every counter in the metadata must be strictly positive.
    kg_metadata = knowledge_graph.dataset.read_metadata()
    assert kg_metadata.tot_triple > 0
    assert kg_metadata.tot_valid_triples > 0
    assert kg_metadata.tot_test_triples > 0
    assert kg_metadata.tot_train_triples > 0
    assert kg_metadata.tot_relation > 0
    assert kg_metadata.tot_entity > 0

    # Each cached structure must be readable and non-empty; the key is used
    # as the assertion message so a failure names the offending cache entry.
    cache_keys = (
        'triplets_train',
        'triplets_test',
        'triplets_valid',
        'hr_t',
        'tr_h',
        'idx2entity',
        'idx2relation',
        'entity2idx',
        'relation2idx',
    )
    for cache_key in cache_keys:
        assert len(knowledge_graph.read_cache_data(cache_key)) > 0, cache_key


def test_userdefined_dataset():
    """Prepare the custom dataset under ``resource/`` and verify its counts.

    All cached structures and the dataset metadata are read back as a smoke
    check (their return values are not inspected), then the metadata held
    on the knowledge graph is compared against the expected totals.
    """
    resource_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resource')
    custom_dataset_path = os.path.join(resource_dir, 'custom_dataset')

    kg = KnowledgeGraph(dataset="userdefineddataset", custom_dataset_path=custom_dataset_path)
    kg.prepare_data()
    kg.dump()

    # Only checks that each cache entry can be read without raising.
    for cache_key in (
            'triplets_train',
            'triplets_test',
            'triplets_valid',
            'hr_t',
            'tr_h',
            'idx2entity',
            'idx2relation',
            'entity2idx',
            'relation2idx',
    ):
        kg.read_cache_data(cache_key)

    kg.dataset.read_metadata()
    kg.dataset.dump()

    expected_totals = (
        ('tot_train_triples', 1),
        ('tot_test_triples', 1),
        ('tot_valid_triples', 1),
        ('tot_entity', 6),
        ('tot_relation', 3),
    )
    for field_name, expected in expected_totals:
        assert getattr(kg.kg_meta, field_name) == expected


@pytest.mark.parametrize('file_name, new_ext', [
    ('dataset.tar.gz', 'tgz'),
    ('dataset.tgz', 'tgz'),
    ('dataset.zip', 'zip'),
])
def test_extract_compressed_dataset(file_name, new_ext):
    """Check the files left behind after building a dataset from an archive.

    The archive in ``resource/`` is addressed through a ``file://`` URI and
    passed to ``KnownDataset`` (presumably fetched and unpacked by its
    constructor -- confirm against ``KnownDataset``). The dataset directory
    must then hold exactly four entries: the archive named with ``new_ext``
    and the train/test/valid text files.

    Args:
        file_name: archive file name inside the ``resource`` directory.
        new_ext: extension the stored archive is expected to carry.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    archive_path = os.path.join(here, 'resource', file_name)
    url = Path(archive_path).absolute().as_uri()

    dataset_name = 'test_dataset_%s' % file_name.replace('.', '_')
    dataset = KnownDataset(dataset_name, url, 'userdefineddataset-')

    extracted = os.listdir(os.path.join(dataset.dataset_home_path, dataset_name))
    assert len(extracted) == 4

    archive_name = dataset_name + '.' + new_ext
    assert archive_name in extracted
    for split_file in (
            'userdefineddataset-train.txt',
            'userdefineddataset-test.txt',
            'userdefineddataset-valid.txt',
    ):
        assert split_file in extracted