Skip to content

Commit 3f9c0b5

Browse files
authored
Merge pull request #57 from DynamicsAndNeuralSystems/jmoo2880-add-new-spi-testing
New benchmarking dataset and dependency updates
2 parents 9da4258 + 409208a commit 3f9c0b5

File tree

11 files changed

+135
-71
lines changed

11 files changed

+135
-71
lines changed

.github/workflows/run_unit_tests.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ jobs:
2525
python -m pip install --upgrade pip
2626
pip install -r requirements.txt
2727
pip install .
28-
pip install pandas==1.3.3 numpy==1.22.0
2928
- name: Run pyspi calculator unit tests
3029
run: |
3130
pytest -v ./tests/test_calc.py

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<p align="center">
22
<picture>
3-
<source srcset="img/pyspi_logo_dark.png" media="(prefers-color-scheme: dark)">
3+
<source srcset="img/pyspi_logo_darkmode.png" media="(prefers-color-scheme: dark)">
44
<img src="img/pyspi_logo.png" alt="pyspi logo" height="200"/>
55
</picture>
66
</p>

img/pyspi_logo_dark.png

-59.7 KB
Binary file not shown.

img/pyspi_logo_darkmode.png

83 KB
Loading

pyspi/data/cml7.npy

5.59 KB
Binary file not shown.

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
pytest
2-
scikit-learn==0.24.1
2+
scikit-learn==1.0.1
33
scipy==1.7.3
44
numpy>=1.21.1
5-
pandas>=1.3.3
5+
pandas==1.5.0
66
statsmodels==0.12.1
77
pyyaml==5.4
88
tqdm==4.50.2

setup.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
# http://www.diveintopython3.net/packaging.html
44
# https://pypi.python.org/pypi?:action=list_classifiers
55

6-
with open('README.md') as file:
6+
with open('README.md', 'r', encoding='utf-8') as file:
77
long_description = file.read()
88

99

1010
install_requires = [
11-
'scikit-learn==0.24.1',
11+
'scikit-learn==1.0.1',
1212
'scipy==1.7.3',
1313
'numpy>=1.21.1',
14-
'pandas>=1.3.3',
14+
'pandas==1.5.0',
1515
'statsmodels==0.12.1',
1616
'pyyaml==5.4',
1717
'tqdm==4.50.2',
@@ -59,9 +59,10 @@
5959
'lib/PhiToolbox/utility/Gauss/logdet.m',
6060
'data/cml.npy',
6161
'data/forex.npy',
62-
'data/standard_normal.npy']},
62+
'data/standard_normal.npy',
63+
'data/cml7.npy']},
6364
include_package_data=True,
64-
version='0.4.1',
65+
version='0.4.2',
6566
description='Library for pairwise analysis of time series data.',
6667
author='Oliver M. Cliff',
6768
author_email='oliver.m.cliff@gmail.com',

tests/CML7_benchmark_tables.pkl

484 KB
Binary file not shown.

tests/calc_standard_normal.pkl

-1.1 MB
Binary file not shown.

tests/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
3+
@pytest.fixture(scope="session")
4+
def spi_warning_logger(request):
5+
warnings_log = list()
6+
7+
def add_warning(spi, module_name, max_z, num_exceed, num_iteractions):
8+
warnings_log.append((spi, module_name, max_z, num_exceed, num_iteractions))
9+
10+
request.session.spi_warnings = warnings_log
11+
return add_warning
12+
13+
def pytest_sessionfinish(session, exitstatus):
14+
# retrieve the spi warnings from the session object
15+
spi_warnings = getattr(session, 'spi_warnings', [])
16+
17+
# styling
18+
header_line = "=" * 80
19+
content_line = "-" * 80
20+
footer_line = "=" * 80
21+
header = " SPI BENCHMARKING SUMMARY"
22+
footer = f" Session completed with exit status: {exitstatus} "
23+
padded_header = f"{header:^80}"
24+
padded_footer = f"{footer:^80}"
25+
26+
print("\n")
27+
print(header_line)
28+
print(padded_header)
29+
print(header_line)
30+
31+
# print problematic SPIs in table format
32+
if spi_warnings:
33+
print(f"\nDetected {len(spi_warnings)} SPI(s) with outputs exceeding the specified 2 sigma threshold.\n")
34+
35+
# table header
36+
print(f"{'SPI':<25}{'Cat':<10}{'Max ZSc.':>10}{'# Exceed. Pairs':>20}{'Unq. Pairs':>15}")
37+
print(content_line)
38+
39+
# table content
40+
for est, module_name, max_z, num_exceed, num_iteractions in spi_warnings:
41+
# add special character for v.large zscores
42+
error = ""
43+
if max_z > 10:
44+
error = " **"
45+
print(f"{est+error:<25}{module_name:<10}{max_z:>10.4g}{num_exceed:>15}{num_iteractions:>20}")
46+
else:
47+
print("\n\nNo SPIs exceeded the sigma threshold.\n")
48+
49+
print(footer_line)
50+
print(padded_footer)

tests/test_SPIs.py

Lines changed: 76 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,107 @@
33
import dill
44
import pyspi
55
import numpy as np
6-
from copy import deepcopy
7-
import warnings
6+
7+
88

99

1010
############# Fixtures and helper functions #########
1111

12-
def load_benchmark_calcs():
13-
benchmark_calcs = dict()
14-
calcs = ['calc_standard_normal.pkl'] # follow this naming convention -> calc_{name}.pkl
15-
for calc in calcs:
16-
# extract the calculator name from the filename
17-
calc_name = calc[len("calc_"):-len(".pkl")]
18-
19-
# Load the calculator
20-
with open(f"tests/{calc}", "rb") as f:
21-
loaded_calc = dill.load(f)
22-
benchmark_calcs[calc_name] = loaded_calc
12+
def load_benchmark_tables():
13+
"""Function to load the mean and standard deviation tables for each MPI."""
14+
table_fname = 'CML7_benchmark_tables.pkl'
15+
with open(f"tests/{table_fname}", "rb") as f:
16+
loaded_tables = dill.load(f)
2317

24-
return benchmark_calcs
25-
26-
def load_benchmark_datasets():
27-
benchmark_datasets = dict()
28-
dataset_names = ['standard_normal.npy']
29-
for dname in dataset_names:
30-
dataset = np.load(f"pyspi/data/{dname}")
31-
dataset = dataset.T
32-
benchmark_datasets[dname.strip('.npy')] = dataset
18+
return loaded_tables
3319

34-
return benchmark_datasets
20+
def load_benchmark_dataset():
21+
dataset_fname = 'cml7.npy'
22+
dataset = np.load(f"pyspi/data/{dataset_fname}").T
23+
return dataset
3524

3625
def compute_new_tables():
3726
"""Compute new tables using the same benchmark dataset(s)."""
38-
benchmark_datasets = load_benchmark_datasets()
39-
# Compute new tables on the benchmark datasets
40-
new_calcs = dict()
41-
42-
calc_base = Calculator() # create base calculator object
43-
44-
for dataset in benchmark_datasets.keys():
45-
calc = deepcopy(calc_base) # make a copy of the base calculator
46-
calc.load_dataset(dataset=benchmark_datasets[dataset])
47-
calc.compute()
48-
new_calcs[dataset] = calc
27+
benchmark_dataset = load_benchmark_dataset()
28+
# Compute new tables on the benchmark dataset
29+
np.random.seed(42)
30+
calc = Calculator(dataset=benchmark_dataset)
31+
calc.compute()
32+
table_dict = dict()
33+
for spi in calc.spis:
34+
table_dict[spi] = calc.table[spi]
4935

50-
return new_calcs
36+
return table_dict
5137

5238
def generate_SPI_test_params():
53-
"""Generate combinations of calculator, dataset and SPI for the fixture."""
54-
benchmark_calcs = load_benchmark_calcs()
55-
new_calcs = compute_new_tables()
39+
"""Function to generate combinations of benchmark table,
40+
new table for each MPI"""
41+
benchmark_tables = load_benchmark_tables()
42+
new_tables = compute_new_tables()
5643
params = []
57-
for calc_name, benchmark_calc in benchmark_calcs.items():
58-
spi_dict = benchmark_calc.spis
59-
for spi_est in spi_dict.keys():
60-
params.append((calc_name, spi_est, benchmark_calc.table[spi_est], new_calcs[calc_name].table[spi_est]))
44+
calc = Calculator()
45+
spis = list(calc.spis.keys())
46+
spi_ob = list(calc.spis.values())
47+
for spi_est, spi_ob in zip(spis, spi_ob):
48+
params.append((spi_est, spi_ob, benchmark_tables[spi_est], new_tables[spi_est].to_numpy()))
6149

6250
return params
6351

6452
params = generate_SPI_test_params()
6553
def pytest_generate_tests(metafunc):
6654
"""Create a hook to generate parameter combinations for parameterised test"""
67-
if "calc_name" in metafunc.fixturenames:
68-
metafunc.parametrize("calc_name,est,mpi_benchmark,mpi_new", params)
55+
if "est" in metafunc.fixturenames:
56+
metafunc.parametrize("est, est_ob, mpi_benchmark,mpi_new", params)
57+
6958

70-
def test_mpi(calc_name, est, mpi_benchmark, mpi_new):
59+
def test_mpi(est, est_ob, mpi_benchmark, mpi_new, spi_warning_logger):
7160
"""Run the benchmarking tests."""
72-
73-
"""First check to see if any SPIs are 'broken', as would be the case if
74-
the benchmark table contains values for certain SPIs whereas the new table for the same
75-
SPI does not (NaN). Also, if all values are NaNs for one SPI and not for the same SPI in the
76-
newly computed table. """
61+
zscore_threshold = 2 # 2 sigma
7762

78-
mismatched_nans = (mpi_benchmark.isna() != mpi_new.isna())
79-
assert not mismatched_nans.any().any(), f"SPI: {est} | Dataset: {calc_name}. Mismatched NaNs."
63+
# separate the the mean and std. dev tables for the benchmark
64+
mean_table = mpi_benchmark['mean']
65+
std_table = mpi_benchmark['std']
8066

81-
# check that the shapes are equal
82-
assert mpi_benchmark.shape == mpi_new.shape, f"SPI: {est}| Dataset: {calc_name}. Different table shapes. "
67+
# check std stable for zeros and impute with smallest non-zero value
68+
min_nonzero_std = np.nanmin(std_table[std_table > 0])
69+
std_table[std_table == 0] = min_nonzero_std
70+
8371

84-
# Now quantify the difference between tables (if a diff exists)
85-
epsilon = np.finfo(float).eps
72+
# check that the shapes are equal
73+
assert mean_table.shape == mpi_new.shape, f"SPI: {est}| Different table shapes. "
74+
75+
# convert NaNs to zeros before proeceeding - this will take care of diagonal and any null outputs
76+
mpi_new = np.nan_to_num(mpi_new)
77+
mpi_mean = np.nan_to_num(mean_table)
78+
79+
# check if matrix is symmetric (undirected SPI) for num exceed correction
80+
isSymmetric = "undirected" in est_ob.labels
81+
82+
# get the module name for easy reference
83+
module_name = est_ob.__module__.split(".")[-1]
84+
85+
if (mpi_new == mpi_mean).all() == False:
86+
# tables are not equivalent, quantify the difference by z-scoring.
87+
diff = abs(mpi_new - mpi_mean)
88+
zscores = diff/std_table
89+
idxs_greater_than_thresh = np.argwhere(zscores > zscore_threshold)
90+
if len(idxs_greater_than_thresh) > 0:
91+
sigs = list()
92+
for idx in idxs_greater_than_thresh:
93+
sigs.append(zscores[idx[0], idx[1]])
94+
# get the max
95+
max_z = max(sigs)
96+
# number of interactions
97+
num_iteractions = (mpi_new.shape[0] * mpi_new.shape[1]) - mpi_new.shape[0]
98+
# count exceedances
99+
num_exceed = len(sigs)
100+
if isSymmetric:
101+
# number of unique exceedences is half
102+
num_exceed /= 2
103+
num_iteractions /= 2
104+
105+
spi_warning_logger(est, module_name, max_z, int(num_exceed), int(num_iteractions))
86106

87-
if not mpi_benchmark.equals(mpi_new):
88-
diff = abs(mpi_benchmark - mpi_new)
89-
max_diff = diff.max().max()
90-
if max_diff > epsilon:
91-
warnings.warn(f"SPI: {est} | Dataset: {calc_name} | Max difference: {max_diff}")
92107

93108

94109

95-

0 commit comments

Comments
 (0)