import numpy as np
from tqdm.auto import tqdm
import multiprocessing as mp
from scipy.stats import circmean
from . import base
from .. import data as aodata
from .. import precondition
from . import accllr
def get_acq_ch_near_stimulation_site(stim_site, stim_layout='Opto32', electrode_layout='ECoG244',
                                     dist_thr=1, return_idx=False):
    '''
    Find the acquisition channels whose electrodes lie close to a stimulation site.

    Both layouts are loaded with :func:`~aopy.data.load_chmap`; the distance between every
    electrode position and the position of the requested stimulation site is then compared
    against ``dist_thr``.

    Note:
        The returned values are channel *numbers* taken from the electrode layout, which
        sometimes are 1-indexed. Set ``return_idx=True`` to also receive the 0-based positions
        of the selected channels within the layout.

    Args:
        stim_site (int): stimulation site (must match a channel in the stim_layout)
        stim_layout (str): layout of stimulation sites, e.g. 'Opto32'. See
            :func:`~aopy.data.load_chmap` for options.
        electrode_layout (str): layout of electrodes, e.g. 'ECoG244'. See
            :func:`~aopy.data.load_chmap` for options.
        dist_thr (float or tuple (min, max), optional): distance threshold from the stimulation
            site (in the same units as the electrode layout, typically mm). A single value keeps
            electrodes with ``dist < dist_thr``; a (min, max) pair keeps electrodes with
            ``min <= dist < max``. Default is 1.
        return_idx (bool, optional): if True, return the channel indices as well. Default is False.

    Returns:
        acq_ch: np.ndarray, acquisition channels near stimulation site
            | or **(acq_ch, idx) tuple**, if return_idx is True

    Raises:
        ValueError: if ``stim_site`` is not present in ``stim_layout`` or if ``dist_thr`` has
            neither one nor two elements.
    '''
    electrode_xy, electrode_ch, _ = aodata.load_chmap(electrode_layout)
    stim_xy, stim_ids, _ = aodata.load_chmap(stim_layout)

    # Position of the requested site; empty when the site is not part of the layout
    site_xy = stim_xy[stim_ids == stim_site]
    if site_xy.size == 0:
        raise ValueError(f"stim_site site {stim_site} not found in layout {stim_layout}")

    # Euclidean distance from every electrode to the stimulation site
    distance = np.linalg.norm(electrode_xy - site_xy, axis=1)

    n_thr = np.size(dist_thr)
    if n_thr not in (1, 2):
        raise ValueError("dist_thr must be a float or tuple (min, max) of floats")
    if n_thr == 1:
        nearby = distance < dist_thr
    else:
        nearby = (distance < dist_thr[1]) & (distance >= dist_thr[0])

    selected = acq = electrode_ch[nearby]
    if not return_idx:
        return selected
    return acq, np.nonzero(nearby)[0]
def prepare_erp(erp, samplerate, time_before, time_after,
                window_nullcond, window_altcond, zscore=False, ref=False):
    '''
    Prepare data for connectivity analysis. Given event-related potentials, extracts the
    alternative-condition sub-window and subtracts the per-channel, per-trial mean of the
    null-condition (baseline) sub-window. Optionally z-scores by the baseline standard
    deviation and re-references to the mean across channels. The input array is not modified.

    Args:
        erp ((nt, nch, ntr) array): trial-aligned data. Integer data is converted to float64;
            floating-point (or complex) data keeps its dtype.
        samplerate (float): sampling rate of the erps
        time_before (float): time before event in the erp (in seconds)
        time_after (float): time after event in the erp (in seconds)
        window_nullcond ((2,) tuple of float): desired (start, end) of nullcond (in seconds)
        window_altcond ((2,) tuple of float): desired (start, end) of altcond (in seconds)
        zscore (bool, optional): if True, divide by the standard deviation of the null
            condition (per channel and trial). Default is False.
        ref (bool, optional): if True, subtract the mean across channels. Default is False.

    Returns:
        ((nt_altcond, nch, ntr) array): alternative condition sub-window of the prepared erp

    Raises:
        AssertionError: if a window is not an increasing (start, end) pair or does not fit
            inside ``[-time_before, time_after]``.
    '''
    assert len(window_nullcond) == 2 and window_nullcond[1] > window_nullcond[0]
    assert len(window_altcond) == 2 and window_altcond[1] > window_altcond[0]
    assert window_nullcond[0] >= -time_before
    assert window_altcond[1] <= time_after
    # Without these two checks an out-of-range window silently yields wrapped-around
    # (negative index) or truncated data.
    assert window_altcond[0] >= -time_before, "window_altcond must start within the erp"
    assert window_nullcond[1] <= time_after, "window_nullcond must end within the erp"

    # Find start and end indices. The alternative window begins one sample before its nominal
    # onset (kept from the original implementation); it is clamped at 0 so that a window
    # starting exactly at -time_before does not turn into the negative index -1.
    altcond_start = max(int((time_before + window_altcond[0]) * samplerate) - 1, 0)
    altcond_dur = window_altcond[1] - window_altcond[0]
    altcond_end = altcond_start + int(altcond_dur * samplerate)
    nullcond_start = int((time_before + window_nullcond[0]) * samplerate)
    nullcond_dur = window_nullcond[1] - window_nullcond[0]
    nullcond_end = nullcond_start + int(nullcond_dur * samplerate)

    # Extract data. astype() always returns a new array, so the caller's erp is untouched.
    # Integer input is promoted to float because the in-place operations below would
    # otherwise fail to cast their floating-point results back to an integer array.
    data_altcond = erp[altcond_start:altcond_end, :, :]
    if np.issubdtype(data_altcond.dtype, np.inexact):
        work_dtype = data_altcond.dtype
    else:
        work_dtype = np.float64
    data_altcond = data_altcond.astype(work_dtype)
    data_nullcond = erp[nullcond_start:nullcond_end, :, :]  # only read, no copy needed

    # Remove the baseline mean of each (channel, trial) from the alternative condition
    data_altcond -= np.mean(data_nullcond, axis=0)

    # Z-score the data using the baseline standard deviation
    if zscore:
        data_altcond /= np.std(data_nullcond, axis=0)

    # Re-reference the data
    if ref:
        data_altcond = data_altcond - np.mean(data_altcond, axis=1, keepdims=True)  # mean across channels

    return data_altcond
def calc_connectivity_coh(data_altcond_source, data_altcond_probe, n, p, k,
                          samplerate, step, fk=250, pad=2,
                          imaginary=True, average=True):
    '''
    Calculate the time-frequency coherogram between multiple source and probe channels.
    Iterates through every (probe, source) channel pair and calculates the coherence between
    them with :func:`~aopy.analysis.base.calc_mt_tfcoh`. Optionally returns the average
    across all pairs. This function is called by :func:`calc_connectivity_map_coh` to
    calculate the coherence between a single channel and multiple channels around the
    stimulation site. No re-referencing is done here, if you want to re-reference the data,
    do it before calling this function.

    Args:
        data_altcond_source (nt, n_source, ntrial): source erp data
        data_altcond_probe (nt, n_probe, ntrial): probe erp data
        n (float): window length in seconds
        p (float): standardized half bandwidth in hz
        k (int): number of DPSS tapers to use
        samplerate (float): sampling rate in Hz.
        step (float): window step size in seconds.
        fk (float, optional): frequency range to return in Hz ([0, fk]). Default is 250.
        pad (int, optional): padding factor for the FFT. This should be 1 or a multiple of 2.
            For nt=500, if pad=1, we pad the FFT to 512 points.
            If pad=2, we pad the FFT to 1024 points.
            If pad=4, we pad the FFT to 2048 points.
            Default is 2.
        imaginary (bool, optional): if True, compute imaginary coherence. Default is True.
        average (bool, optional): whether to average the coherence across all pairs. Angles
            are averaged using circular statistics. Default is True.

    Returns:
        tuple: tuple containing:
            | **f (n_freq):** frequency axis
            | **t (nt):** time axis
            | **coh (list of (n_freq,nt)):** magnitude squared coherence or imaginary coherence
                between the pairs, as returned by the underlying coherence routine
            | **angle (list of (n_freq,nt)):** list of phase differences (in radians) between
                the pairs, wrapped into [0, 2*pi); for each pair this is approximately
                phase(probe) - phase(source)
            | **pair (list of tuples):** list of ``(probe_idx, source_idx)`` channel pairs,
                each index relative to its own input array, in the same order as ``coh``
            | or **(freqs, time, coh, angle)** tuple, if average is True

    Raises:
        ValueError: if either input contains no channels.
    '''
    n_source = data_altcond_source.shape[1]
    n_probe = data_altcond_probe.shape[1]
    if n_source == 0 or n_probe == 0:
        raise ValueError("at least one source channel and one probe channel are required")

    # Probe channels occupy columns [0, n_probe), source channels follow at an offset of n_probe
    data_altcond = np.concatenate((data_altcond_probe, data_altcond_source), axis=1)

    stim_coh = []
    stim_angle = []
    pair = []
    for source_idx in range(n_source):
        for probe_idx in range(n_probe):
            # Probe and source columns never overlap, so every pair is visited exactly once.
            ch_pair = np.array([probe_idx, n_probe + source_idx])  # for [probe, source], angle ≈ phase(probe) - phase(source)
            freqs, time, coh, angle = base.calc_mt_tfcoh(data_altcond, ch_pair, n, p, k, samplerate,
                                                         step=step, fk=fk, pad=pad, imaginary=imaginary,
                                                         ref=False, return_angle=True)
            angle = (angle + 2*np.pi) % (2*np.pi)  # wrap the angle from [-pi, pi] to [0, 2*pi]
            stim_coh.append(coh)
            stim_angle.append(angle)
            # Store an ordered tuple (not a set): set iteration order is arbitrary and would
            # scramble which entry is the probe and which is the source.
            pair.append((probe_idx, source_idx))

    if average:
        return freqs, time, np.mean(stim_coh, axis=0), circmean(stim_angle, axis=0)
    return freqs, time, stim_coh, stim_angle, pair
def calc_connectivity_map_coh(erp, samplerate, time_before, time_after, stim_ch_idx, window=None,
                              n=0.06, step=0.03, bw=25, zscore=False, ref=True, parallel=False,
                              verbose=True, imaginary=True, **kwargs):
    '''
    Map of coherence at every channel to the given stimulation channels. Input ERP data must include
    at least `n` seconds before and after events. Coherence is averaged across stimulation channels
    if multiple are given.

    Args:
        erp ((nt, nch, ntr) array): trial-aligned data
        samplerate (float): sampling rate of the erp
        time_before (float): time included before events in the erp (in seconds)
        time_after (float): time included after events in the erp (in seconds)
        stim_ch_idx (int or list of 0-indexed int): stimulation channel indices (where you want
            coherence to be calculated from). A single integer is treated as a one-element list.
        window (2-tuple, optional): time window for the coherence calculation in seconds. If None, a single (0, n) window
            timestep will be used and the step parameter will be ignored. Default None.
        n (float): window length in seconds for the coherence calculation (default 0.06 s).
        step (float): window step size in seconds for the coherence calculation (default 0.03 s).
        bw (float): bandwidth for multitaper filter (default 25).
        zscore (bool): z-score flag (default False).
        ref (bool): re-referencing flag (default True).
        parallel (bool or mp.pool.Pool): whether to use parallel processing. Can optionally be a pool object
            to use an existing pool. If True, a new pool is created with half of the available CPUs
            (at least one, at most one per channel) and shut down before returning. If False,
            computation is done serially (the default).
        verbose (bool): if True, show a progress bar (default True).
        imaginary (bool): if True, compute imaginary coherence (the default).
        **kwargs: additional keyword arguments passed to :func:`calc_connectivity_coh`
            (e.g. ``fk``, ``pad``). ``average`` must keep its default of True because the
            per-channel results are unpacked as 4-tuples.

    Returns:
        tuple: tuple containing:
            | **freqs (n_freq):** frequency axis
            | **time (nt):** time axis, shifted by ``window[0]``
            | **coh_all (n_freq, nt, nch):** magnitude squared coherence or imaginary coherence
                between the pairs at each channel
            | **angle_all (n_freq, nt, nch):** phase difference (in radians) between the pairs
                at each channel

    Note:
        This is not the most efficient way to compute pairwise coherence since we end up repeating
        the same calculations for each channel multiple times. Maybe a future enhancement. See the
        implementation in the package `spectral_connectivity` for a more time-efficient (but memory-
        inefficient) algorithm.

    Examples:

        Create a grid of channels with mostly noise but two channels have 50 Hz sine waves

        .. code-block:: python

            grid_size = 3
            nch = grid_size**2
            T = 1
            fs = 1000
            nt = int(T*fs)
            ntr = 2
            time = np.linspace(0,T,nt)
            data = np.random.normal(0, 0.1, (nt,nch,ntr)) # start with noise
            stim_ch_idx = 0
            data[:,stim_ch_idx,0] += np.sin(2*np.pi*50*time) # 50 Hz sine
            data[:,stim_ch_idx,1] += np.sin(2*np.pi*50*time)
            data[500:,4,0] += np.cos(2*np.pi*50*time[500:]) # 50 Hz cosine in second half of trial
            data[500:,4,1] += np.cos(2*np.pi*50*time[500:])

            n = 0.25
            w = 10
            step = 0.25
            f, t, coh_all, angle_all = aopy.analysis.connectivity.calc_connectivity_map_coh(data, fs, 0.5, 0.5, [stim_ch_idx],
                                                            window=(-n, n), n=n, bw=w, step=step, ref=False)
            assert coh_all.shape == angle_all.shape

            bands = [(40, 60), (100, 250)]
            x, y = np.meshgrid(np.arange(grid_size), np.arange(grid_size))
            elec_pos = np.zeros((nch,2))
            elec_pos[:,0] = x.reshape(-1)
            elec_pos[:,1] = y.reshape(-1)
            aopy.visualization.plot_tf_map_grid(f, t, coh_all, bands, elec_pos, clim=(0,1), interp_grid=None,
                                                cmap='viridis')

        .. image:: _images/connectivity_map_coh.png
    '''
    # Imported explicitly: ``import multiprocessing`` alone does not guarantee that the
    # ``multiprocessing.pool`` submodule is loaded, so ``mp.pool.Pool`` may not resolve.
    from multiprocessing.pool import Pool

    assert erp.ndim == 3, "ERP data must be 3D (nt, nch, ntr)"
    assert time_before >= n, "time_before must be greater than or equal to n"
    assert time_after >= n, "time_after must be greater than or equal to n"

    n, p, k = precondition.convert_taper_parameters(n, bw)
    if verbose:
        print(f"using {k} tapers for tfcoh")
    if window is None:
        window = (0, n)
    nullcond_window = (-time_before, 0)
    data_altcond = prepare_erp(
        erp, samplerate, time_before, time_after, nullcond_window, window,
        zscore=zscore, ref=ref
    )

    n_ch = erp.shape[1]

    # Keep the channel axis when a single stimulation channel is given as a bare integer
    if np.isscalar(stim_ch_idx):
        stim_ch_idx = [stim_ch_idx]
    stim_data = data_altcond[:, stim_ch_idx, :]  # identical for every channel, so slice once

    # Create a parallel pool if requested
    owns_pool = parallel is True
    if owns_pool:
        # cpu_count()//2 is 0 on a single-core machine; Pool requires at least one worker
        pool = mp.Pool(max(1, min(mp.cpu_count() // 2, n_ch)))
    elif isinstance(parallel, Pool):  # use an existing pool
        pool = parallel
    else:
        pool = None

    kwargs['imaginary'] = imaginary

    def _task_args(ch):
        # Positional arguments for calc_connectivity_coh for a single acquisition channel
        return (data_altcond[:, [ch], :], stim_data, n, p, k, samplerate, step)

    # Calculate coherence for each channel
    if pool is not None:
        try:
            pending = [pool.apply_async(calc_connectivity_coh, args=_task_args(ch), kwds=kwargs)
                       for ch in range(n_ch)]
            fetched = (r.get() for r in pending)
            if verbose:
                fetched = tqdm(fetched, total=n_ch, leave=False)
            results = list(fetched)
        finally:
            # Only shut down a pool created here, and do so even when a worker raised.
            # On success every result has already been collected, so terminate() is safe.
            if owns_pool:
                pool.terminate()
                pool.join()
    else:
        channels = range(n_ch)
        if verbose:
            channels = tqdm(channels, leave=False)
        results = [calc_connectivity_coh(*_task_args(ch), **kwargs) for ch in channels]

    freqs, time, coh_all, angle_all = zip(*results)
    freqs = freqs[0]
    time = time[0]

    # Move channels from the first axis to the last: (nch, n_freq, nt) -> (n_freq, nt, nch)
    coh_all = np.array(coh_all).transpose(1, 2, 0)
    angle_all = np.array(angle_all).transpose(1, 2, 0)
    return freqs, time + window[0], coh_all, angle_all