Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions demos/python/SpikingTE/SpikeTrainAISTesting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from jpype import *
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

RUN_SIMULATION = True

NUM_REPS = 10
NUM_OBSERVATIONS = 1
jarLocation = os.path.join(os.getcwd(), "infodynamics.jar")
if (not(os.path.isfile(jarLocation))):
exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")

nu_nx_ratios = [0.5, 1.0, 2.0, 5.0]
embedding_strings = ["1", "1,2"]
knn_values = [5]
spike_counts = np.logspace(2, 5, 4, dtype=int)

results_csv = 'ais_results.csv'
figure_path = 'ais_vs_spikes_ratios.png'

columns = ['embedding', 'nu_nx_ratio', 'k', 'spike_count', 'mean_ais', 'std_ais']
results_df = pd.DataFrame(columns=columns)

if RUN_SIMULATION:
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
aisCalcClass = JPackage("infodynamics.measures.spiking.integration").ActiveInformationStorageCalculatorSpikingIntegration

print("Testing Active Information Storage on spike trains with varying spike counts")
aisCalc = aisCalcClass()
aisCalc.setProperty("knns", "1")
aisCalc.setProperty("PAST_INTERVALS", "1")
aisCalc.setProperty("DO_JITTERED_SAMPLING", "true")
aisCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "2.0")
aisCalc.setProperty("NORM_TYPE", "MAX_NORM")

for embedding in embedding_strings:
aisCalc.setProperty("PAST_INTERVALS", embedding)
print(f"\n=== Testing embedding {embedding} ===")

for row_idx, ratio in enumerate(nu_nx_ratios):
aisCalc.setProperty("NUM_SAMPLES_MULTIPLIER", str(ratio))

for col_idx, k in enumerate(knn_values):
print(f" N_U/N_X = {ratio}, k = {k}")
aisCalc.setProperty("knns", str(k))

mean_ais = np.zeros(len(spike_counts))
std_ais = np.zeros(len(spike_counts))

for i, num_spikes in enumerate(spike_counts):
rep_results = np.zeros(NUM_REPS)

for rep in range(NUM_REPS):
aisCalc.startAddObservations()
for _ in range(NUM_OBSERVATIONS):
spikeArray = num_spikes * np.random.random(num_spikes)
spikeArray.sort()
aisCalc.addObservations(JArray(JDouble, 1)(spikeArray))
aisCalc.finaliseAddObservations()

rep_results[rep] = aisCalc.computeAverageLocalOfObservations()

mean_ais[i] = np.mean(rep_results)
std_ais[i] = np.std(rep_results)

results_df = results_df._append({
'embedding': embedding,
'nu_nx_ratio': ratio,
'k': k,
'spike_count': num_spikes,
'mean_ais': mean_ais[i],
'std_ais': std_ais[i]
}, ignore_index=True)

results_df.to_csv(results_csv, index=False)
print(f"Simulation results saved to {results_csv}")
else:
try:
results_df = pd.read_csv(results_csv)
print(f"Loaded saved results from {results_csv}")
except FileNotFoundError:
print(f"Error: {results_csv} not found. Set RUN_SIMULATION to True to generate results first.")
exit(1)

check_df = results_df[(results_df['nu_nx_ratio'] == 1.0) & (results_df['k'] == 5)]
sanity = check_df.groupby('embedding')['mean_ais'].mean()
print("\nMean AIS (should stay ~0 for Poisson) by embedding:")
print(sanity.to_string())

fig, axes = plt.subplots(len(nu_nx_ratios),
len(knn_values),
figsize=(10, 12),
sharex=True,
sharey=True)

if axes.ndim == 1:
axes = axes.reshape(len(nu_nx_ratios), len(knn_values))

for row_idx, ratio in enumerate(nu_nx_ratios):
for col_idx, k in enumerate(knn_values):
ax = axes[row_idx, col_idx]

data = results_df[(results_df['embedding'] == "1,2") &
(results_df['nu_nx_ratio'] == ratio) &
(results_df['k'] == k)]

if not data.empty:
plot_spike_counts = np.asarray(data['spike_count'].values, dtype=float)
plot_mean_ais = np.asarray(data['mean_ais'].values, dtype=float)
plot_std_ais = np.asarray(data['std_ais'].values, dtype=float)

ax.set_xscale('log')
ax.set_xlim(10**2, 10**5)
ax.set_ylim(-0.2, 1)

ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5, alpha=0.7)

ax.plot(plot_spike_counts, plot_mean_ais, 'b-', linewidth=2)
ax.fill_between(plot_spike_counts,
plot_mean_ais - plot_std_ais,
plot_mean_ais + plot_std_ais,
alpha=0.3, color='blue')

if col_idx == 0:
ax.set_ylabel('AIS (nats/second)', fontsize=11)
ax.set_title(rf'$N_U/N_X = {ratio}$', loc='left', fontsize=12)
if row_idx == 0:
ax.set_title(f'k = {k}', loc='center', fontsize=12)

for col_idx in range(len(knn_values)):
axes[-1, col_idx].set_xlabel('Number of Events', fontsize=11)

plt.tight_layout()

plt.savefig(figure_path)
print(f"Figure saved to {figure_path}")

plt.show()
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Java Information Dynamics Toolkit (JIDT)
* Copyright (C) 2012, Joseph T. Lizier
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package infodynamics.measures.spiking;

import infodynamics.utils.EmpiricalMeasurementDistribution;

/**
* Interface for implementations of Active Information Storage (AIS) calculators
* on spike trains or other event-based data.
*
* <p>AIS is the mutual information between a process' past and its next state.
* For spiking processes, this measures how much information about future spiking
* is contained in the past spiking patterns.</p>
*
* <p>Usage is intended to follow this paradigm:</p>
* <ol>
* <li>Construct the calculator via an implementing class constructor;</li>
* <li>Set properties using {@link #setProperty(String, String)};</li>
* <li>Initialize the calculator using {@link #initialise()} or
* {@link #initialise(int)};</li>
* <li>Provide the observations/samples for the calculator to set up the PDFs,
* using one or more calls to sets of {@link #setObservations(double[])} or
* {@link #addObservations(double[])} or
* {@link #startAddObservations()}, followed by multiple calls to
* {@link #addObservations(double[])} and then
* {@link #finaliseAddObservations()};</li>
* <li>Compute the required quantities, being one or more of:
* <ul>
* <li>the average AIS: {@link #computeAverageLocalOfObservations()};</li>
* <li>the local AIS values for these samples: {@link #computeLocalOfPreviousObservations()}</li>
*
* <li>the distribution of AIS values under the null hypothesis
* of no relationship between past and future:
* {@link #computeSignificance(int)} or
* {@link #computeSignificance(int, double)}.</li>
* </ul>
* </li>
* </ol>
*
* @author Michael Fang (fangmichael33@gmail.com)
*
*/
public interface ActiveInformationStorageCalculatorSpiking {



/**
* Property name for a comma-separated list of integers representing
* the past intervals to use in the embedding.
*/
public static final String PAST_INTERVALS_PROP_NAME = "PAST_INTERVALS";



/**
* Initialise the calculator for (re-)use, with the existing
* (or default) values of parameters.
*
* @throws Exception
*/
public void initialise() throws Exception;

/**
* Initialise the calculator for (re-)use, with some parameters
* supplied here rather than in later method calls.
*
* @param k Length of past history to consider (i.e. embedding length)
* @throws Exception
*/
public void initialise(int k) throws Exception;

/**
* Set properties for the calculator.
* New property values are not guaranteed to take effect until the next call
* to an initialise method.
*
* <p>Valid property names, and what their
* values should represent, include:</p>
* <ul>
* <li>{@link #PAST_INTERVALS_PROP_NAME} -- a comma-separated list
* of integers representing the past intervals to use in the embedding.</li>
* <li>Any other properties defined by the implementing class.</li>
* </ul>
*
* @param propertyName name of the property to set
* @param propertyValue value of the property to set
* @throws Exception for invalid property values
*/
public void setProperty(String propertyName, String propertyValue) throws Exception;

/**
* Get property values for the calculator.
*
* <p>Valid property names, and what their
* values should represent, are the same as those for
* {@link #setProperty(String, String)}</p>
*
* @param propertyName name of the property
* @return the value of the property
* @throws Exception for invalid property values
*/
public String getProperty(String propertyName) throws Exception;

/**
* Sets a single set of observations for the calculator to use.
* Cannot be called once {@link #startAddObservations()} has been called,
* and cannot be called after {@link #addObservations(double[])} has been
* called.
*
* @param observations time-series array of spike times
* @throws Exception
*/
public void setObservations(double[] observations) throws Exception;

/**
* Signal that we will add in the observations for calculating the PDFs
* from several disjoint time-series or trials.
*
* @throws Exception
*/
public void startAddObservations() throws Exception;

/**
* Add observations for the PDFs for a single time-series.
*
* @param observations time-series array of spike times
* @throws Exception
*/
public void addObservations(double[] observations) throws Exception;

/**
* Signal that we have finished adding in the observations.
*
* @throws Exception
*/
public void finaliseAddObservations() throws Exception;

/**
* Returns whether more than one time-series has been added
* to the calculator (either via {@link #setObservations(double[])}
* or via {@link #addObservations(double[])})
*
* @return true if more than one time-series has been supplied
*/
public boolean getAddedMoreThanOneObservationSet();

/**
* Compute the average AIS from the previously-supplied samples.
*
* @return the estimate of the AIS
* @throws Exception
*/
public double computeAverageLocalOfObservations() throws Exception;

/**
* This interface serves to indicate the return type of {@link #computeLocalOfPreviousObservations()}
* as each child implementation will return something specific
*/
public interface SpikingLocalInformationValues {
// Left empty intentionally
}

/**
* Compute the local AIS values for the previously-supplied samples.
*
* @return an object containing a representation of the local AIS values
* @throws Exception
*/
public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception;

public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception;

public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue, long randomSeed) throws Exception;

/**
* Set whether to display debug messages or not.
*
* @param debug display debug messages if true
*/
public void setDebug(boolean debug);

/**
* Return the AIS last calculated in a call to {@link #computeAverageLocalOfObservations()}
* or {@link #computeLocalOfPreviousObservations()} after the previous
* {@link #initialise()} call.
*
* @return the last computed average AIS value
*/
public double getLastAverage();
}
Loading