 import dill
 import pyspi
 import numpy as np
-from copy import deepcopy
-import warnings
+
+
 
 
 ############# Fixtures and helper functions #########
 
-def load_benchmark_calcs():
-    benchmark_calcs = dict()
-    calcs = ['calc_standard_normal.pkl'] # follow this naming convention -> calc_{name}.pkl
-    for calc in calcs:
-        # extract the calculator name from the filename
-        calc_name = calc[len("calc_"):-len(".pkl")]
-
-        # Load the calculator
-        with open(f"tests/{calc}", "rb") as f:
-            loaded_calc = dill.load(f)
-            benchmark_calcs[calc_name] = loaded_calc
+def load_benchmark_tables():
+    """Load the mean and standard deviation benchmark tables for each MPI."""
+    table_fname = 'CML7_benchmark_tables.pkl'
+    with open(f"tests/{table_fname}", "rb") as f:
+        loaded_tables = dill.load(f)
 
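Note: the internal layout of the benchmark pickle is not shown in this diff; from the ['mean'] and ['std'] lookups in test_mpi further down, it is presumably a dict keyed by SPI identifier holding a mean matrix and a standard-deviation matrix per SPI. A minimal sketch of that assumed layout (the SPI key and matrix size are placeholders, not taken from the PR):

import numpy as np

# Assumed layout of tests/CML7_benchmark_tables.pkl -- illustration only,
# inferred from how test_mpi indexes each entry with ['mean'] and ['std'].
benchmark_tables = {
    "example_spi": {                   # hypothetical SPI identifier
        "mean": np.zeros((7, 7)),      # mean MPI over benchmarking runs (presumably 7x7 for cml7)
        "std": np.full((7, 7), 0.05),  # run-to-run standard deviation of the MPI
    },
}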
-    return benchmark_calcs
-
-def load_benchmark_datasets():
-    benchmark_datasets = dict()
-    dataset_names = ['standard_normal.npy']
-    for dname in dataset_names:
-        dataset = np.load(f"pyspi/data/{dname}")
-        dataset = dataset.T
-        benchmark_datasets[dname.strip('.npy')] = dataset
+    return loaded_tables
 
-    return benchmark_datasets
+def load_benchmark_dataset():
+    dataset_fname = 'cml7.npy'
+    dataset = np.load(f"pyspi/data/{dataset_fname}").T
+    return dataset
 
 def compute_new_tables():
     """Compute new tables using the same benchmark dataset(s)."""
-    benchmark_datasets = load_benchmark_datasets()
-    # Compute new tables on the benchmark datasets
-    new_calcs = dict()
-
-    calc_base = Calculator() # create base calculator object
-
-    for dataset in benchmark_datasets.keys():
-        calc = deepcopy(calc_base) # make a copy of the base calculator
-        calc.load_dataset(dataset=benchmark_datasets[dataset])
-        calc.compute()
-        new_calcs[dataset] = calc
+    benchmark_dataset = load_benchmark_dataset()
+    # Compute new tables on the benchmark dataset
+    np.random.seed(42)
+    calc = Calculator(dataset=benchmark_dataset)
+    calc.compute()
+    table_dict = dict()
+    for spi in calc.spis:
+        table_dict[spi] = calc.table[spi]
 
-    return new_calcs
+    return table_dict
 
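Note: compute_new_tables() stores calc.table[spi] per SPI, and generate_SPI_test_params() below calls .to_numpy() on these entries, so each one behaves like a pandas DataFrame. A self-contained sketch of the same computation, assuming the usual pyspi import path and the bundled cml7.npy location (both assumptions, not confirmed by this diff):

# Standalone sketch of the computation above -- for orientation only.
import numpy as np
from pyspi.calculator import Calculator

data = np.load("pyspi/data/cml7.npy").T        # transpose to the orientation Calculator expects
calc = Calculator(dataset=data)
calc.compute()                                 # compute the pairwise matrix (MPI) for every SPI
first_spi = next(iter(calc.spis))
print(first_spi, type(calc.table[first_spi]))  # each table entry supports .to_numpy(), as used below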
 def generate_SPI_test_params():
-    """Generate combinations of calculator, dataset and SPI for the fixture."""
-    benchmark_calcs = load_benchmark_calcs()
-    new_calcs = compute_new_tables()
+    """Generate combinations of benchmark table and
+    newly computed table for each MPI."""
+    benchmark_tables = load_benchmark_tables()
+    new_tables = compute_new_tables()
     params = []
-    for calc_name, benchmark_calc in benchmark_calcs.items():
-        spi_dict = benchmark_calc.spis
-        for spi_est in spi_dict.keys():
-            params.append((calc_name, spi_est, benchmark_calc.table[spi_est], new_calcs[calc_name].table[spi_est]))
+    calc = Calculator()
+    spis = list(calc.spis.keys())
+    spi_obs = list(calc.spis.values())
+    for spi_est, spi_ob in zip(spis, spi_obs):
+        params.append((spi_est, spi_ob, benchmark_tables[spi_est], new_tables[spi_est].to_numpy()))
 
     return params
 
 params = generate_SPI_test_params()
 def pytest_generate_tests(metafunc):
     """Create a hook to generate parameter combinations for parameterised test"""
-    if "calc_name" in metafunc.fixturenames:
-        metafunc.parametrize("calc_name,est,mpi_benchmark,mpi_new", params)
+    if "est" in metafunc.fixturenames:
+        metafunc.parametrize("est, est_ob, mpi_benchmark, mpi_new", params)
+
 
-def test_mpi(calc_name, est, mpi_benchmark, mpi_new):
+def test_mpi(est, est_ob, mpi_benchmark, mpi_new, spi_warning_logger):
     """Run the benchmarking tests."""
-
-    """First check to see if any SPIs are 'broken', as would be the case if
-    the benchmark table contains values for certain SPIs whereas the new table for the same
-    SPI does not (NaN). Also, if all values are NaNs for one SPI and not for the same SPI in the
-    newly computed table. """
+    zscore_threshold = 2  # 2 sigma
 
-    mismatched_nans = (mpi_benchmark.isna() != mpi_new.isna())
-    assert not mismatched_nans.any().any(), f"SPI: {est} | Dataset: {calc_name}. Mismatched NaNs."
+    # separate the mean and std. dev tables for the benchmark
+    mean_table = mpi_benchmark['mean']
+    std_table = mpi_benchmark['std']
 
-    # check that the shapes are equal
-    assert mpi_benchmark.shape == mpi_new.shape, f"SPI: {est}| Dataset: {calc_name}. Different table shapes. "
+    # check std table for zeros and impute with smallest non-zero value
+    min_nonzero_std = np.nanmin(std_table[std_table > 0])
+    std_table[std_table == 0] = min_nonzero_std
+
 
-    # Now quantify the difference between tables (if a diff exists)
-    epsilon = np.finfo(float).eps
+    # check that the shapes are equal
+    assert mean_table.shape == mpi_new.shape, f"SPI: {est} | Different table shapes."
+
+    # convert NaNs to zeros before proceeding - this will take care of diagonal and any null outputs
+    mpi_new = np.nan_to_num(mpi_new)
+    mpi_mean = np.nan_to_num(mean_table)
+
+    # check if matrix is symmetric (undirected SPI) for num exceed correction
+    isSymmetric = "undirected" in est_ob.labels
+
+    # get the module name for easy reference
+    module_name = est_ob.__module__.split(".")[-1]
+
+    if not (mpi_new == mpi_mean).all():
+        # tables are not equivalent, so quantify the difference by z-scoring
+        diff = abs(mpi_new - mpi_mean)
+        zscores = diff / std_table
+        idxs_greater_than_thresh = np.argwhere(zscores > zscore_threshold)
+        if len(idxs_greater_than_thresh) > 0:
+            sigs = list()
+            for idx in idxs_greater_than_thresh:
+                sigs.append(zscores[idx[0], idx[1]])
+            # get the max z-score
+            max_z = max(sigs)
+            # number of off-diagonal interactions
+            num_interactions = (mpi_new.shape[0] * mpi_new.shape[1]) - mpi_new.shape[0]
+            # count exceedances
+            num_exceed = len(sigs)
+            if isSymmetric:
+                # number of unique exceedances is halved for undirected SPIs
+                num_exceed /= 2
+                num_interactions /= 2
+
+            spi_warning_logger(est, module_name, max_z, int(num_exceed), int(num_interactions))
 
-    if not mpi_benchmark.equals(mpi_new):
-        diff = abs(mpi_benchmark - mpi_new)
-        max_diff = diff.max().max()
-        if max_diff > epsilon:
-            warnings.warn(f"SPI: {est} | Dataset: {calc_name} | Max difference: {max_diff}")
 
 
 
-
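Note: the reworked test_mpi takes a spi_warning_logger fixture that is not defined in this file's diff (presumably it lives in the suite's conftest.py). A rough sketch of what such a fixture could look like, mirroring only the call signature used above; the names and behaviour here are assumptions, not the PR's actual implementation:

# Hypothetical stand-in for the spi_warning_logger fixture -- the real fixture
# is defined outside this diff; only the call signature is taken from test_mpi.
import warnings
import pytest

@pytest.fixture
def spi_warning_logger():
    def _log(spi_name, module_name, max_z, num_exceed, num_interactions):
        warnings.warn(
            f"[{module_name}] SPI: {spi_name} exceeded the z-score threshold on "
            f"{num_exceed}/{num_interactions} interactions (max z = {max_z:.2f})"
        )
    return _log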