# Source code for aopy.analysis.celltype

# celltype.py
# 
# Cell-type specific computations, e.g. classification by spike width

import numpy as np
from sklearn.mixture import GaussianMixture

from .base import find_outliers, get_pca_dimensions, interpolate_extremum_poly2

'''
Cell type classification analysis
'''
def classify_cells_spike_width(waveform_data, samplerate, std_threshold=3, pca_varthresh=0.75, min_wfs=10):
    '''
    Calculates waveform width and classifies units into putative excitatory and inhibitory cell types based on pulse width.

    Units with lower spike width are considered inhibitory cells (label: 0) and higher spike width are considered
    excitatory cells (label: 1). The pulse width is defined as the time between the waveform trough and the
    waveform peak (trough-to-peak time, TTP). Assumes all waveforms are recorded for the same number of time points.

    This function conducts the following processing steps:

        | **1. ** For each unit, project each waveform into the top PCs. Number of PCs determined by 'pca_varthresh'.
        | **2. ** For each unit, remove outlier spikes. Outlier threshold determined by 'std_threshold'. If the number of waveforms is less than 'min_wfs', no waveforms are removed.
        | **3. ** For each unit, average the remaining waveforms.
        | **4. ** For each unit, calculate spike width using a local polynomial (parabolic) interpolation of the trough and peak.
        | **5. ** Use a two-component gaussian mixture model on the spike widths to classify all units.

    Args:
        waveform_data (nunit long list of (nt x nwaveforms) arrays): Waveforms of each unit. Each element of the list is a 2D array for each unit. Each 2D array contains the timeseries of all recorded waveforms for a given unit.
        samplerate (float): sampling rate of the points in each waveform, in Hz (TTP is reported in microseconds).
        std_threshold (float): For outlier removal. The maximum number of standard deviations (in PC space) away from the mean a given waveform is allowed to be. Defaults to 3.
        pca_varthresh (float): Variance threshold for determining the number of dimensions to project spiking data onto. Defaults to 0.75.
        min_wfs (int): Minimum number of waveform samples required to perform outlier detection. Defaults to 10.

    Returns:
        tuple: A tuple containing
            | **TTP (list of nunit floats):** Spike width (trough-to-peak time) of each unit. [us]
            | **unit_labels (nunit):** Label of each unit. 0: low spike width (inhibitory), 1: high spike width (excitatory). Labels are oriented so the unit with the smallest TTP is always labeled 0.
            | **avg_wfs (nt, nunit):** Average waveform of the accepted waveforms for each unit (one column per unit).
            | **sss_unitid (list):** Indices of units with fewer waveforms than 'min_wfs' (outlier removal was skipped for these units).

    Raises:
        ValueError: If fewer than two units are given; the two-component gaussian mixture cannot be fit to a single sample.
    '''
    # Get data size parameters. All units are assumed to share the same number of time points.
    nt, _ = waveform_data[0].shape
    nunits = len(waveform_data)

    # The two-component gaussian mixture below needs at least two samples (units). Fail early
    # with a clear message instead of after processing every unit.
    if nunits < 2:
        raise ValueError(f"classify_cells_spike_width requires at least 2 units, got {nunits}")

    TTP = []
    sss_unitid = []
    # One column per unit
    avg_wfs = np.zeros((nt, nunits))

    for iunit, iwfdata in enumerate(waveform_data):
        # iwfdata has shape (nt, nwaveforms): all recorded waveforms of this unit
        nwfs = iwfdata.shape[1]

        if nwfs >= min_wfs:
            # Use each time point as a feature and each spike as a sample, project onto the top PCs,
            # then drop waveforms that lie too far from the mean in PC space.
            _, _, iwfdata_proj = get_pca_dimensions(iwfdata.T, max_dims=None, VAF=pca_varthresh, project_data=True)
            good_wf_idx, _ = find_outliers(iwfdata_proj, std_threshold)
        else:
            # Too few waveforms for reliable outlier detection: keep them all and record the unit.
            good_wf_idx = np.arange(nwfs)
            sss_unitid.append(iunit)

        # Average the accepted waveforms
        iwfdata_good_avg = np.mean(iwfdata[:, good_wf_idx], axis=1)
        avg_wfs[:, iunit] = iwfdata_good_avg

        # 1st order (sample-index) trough/peak estimate
        troughidx_1st, peakidx_1st = find_trough_peak_idx(iwfdata_good_avg)

        # Refine both extrema with a parabolic fit for sub-sample precision
        troughidx_2nd, _, _ = interpolate_extremum_poly2(troughidx_1st, iwfdata_good_avg, extrap_peaks=False)
        peakidx_2nd, _, _ = interpolate_extremum_poly2(peakidx_1st, iwfdata_good_avg, extrap_peaks=False)

        # 2nd order TTP approximation, converted from samples to microseconds
        TTP.append(1e6 * (peakidx_2nd - troughidx_2nd) / samplerate)

    # Cluster the 1D spike widths into two groups
    ttp_features = np.array(TTP).reshape(-1, 1)
    gmm_proc = GaussianMixture(n_components=2, random_state=0).fit(ttp_features)
    unit_labels = gmm_proc.predict(ttp_features)

    # GMM component numbering is arbitrary; ensure the lowest-TTP unit is labeled inhibitory (0)
    if unit_labels[np.argmin(TTP)] == 1:
        unit_labels = 1 - unit_labels

    return TTP, unit_labels, avg_wfs, sss_unitid
def _first_peak_after(waveform, start):
    '''
    Returns the index of the first local maximum at or after 'start' in a 1D waveform.

    The peak is the first index i >= start that is followed by a decrease (waveform[i+1] < waveform[i]).
    If the waveform never decreases after 'start', the last index is returned.
    '''
    decreasing = np.flatnonzero(np.diff(waveform[start:]) < 0)
    if decreasing.size == 0:
        return len(waveform) - 1
    return start + decreasing[0]


def find_trough_peak_idx(unit_data):
    '''
    Calculates the trough-to-peak time at the index level (0th order).

    The trough index is the index of the minimum value of the waveform. The peak index is the first
    index after the trough at which the waveform starts to decrease, i.e. where its first difference
    is first negative. If the waveform never decreases after the trough, the last index is used as
    the peak.

    Args:
        unit_data (nt, nch): Array of waveforms (can be a 1D array with dimension nt)

    Returns:
        tuple: A tuple containing
            | **troughidx (nch):** Integer indices of the trough for each channel (a scalar index for 1D input)
            | **peakidx (nch):** Integer indices of the peak for each channel (a scalar index for 1D input)
    '''
    # Handle the condition where the input data is a single 1D waveform
    if unit_data.ndim == 1:
        troughidx = np.argmin(unit_data)
        peakidx = _first_peak_after(unit_data, troughidx)
        return troughidx, peakidx

    # 2D input: one waveform per column. Peak indices use an integer dtype so they can be used for
    # indexing and match the dtype of the trough indices.
    troughidx = np.argmin(unit_data, axis=0)
    peakidx = np.array(
        [_first_peak_after(unit_data[:, ich], itrough) for ich, itrough in enumerate(troughidx)],
        dtype=int,
    )
    return troughidx, peakidx
def get_unit_spiking_mean_variance(spiking_data):
    '''
    Calculates the mean spike count and the spike count variance across trials for each unit.

    For every unit and trial, spikes are summed over the time axis to obtain a spike count. The
    mean and variance (numpy's default ddof=0) of those counts are then taken across trials.

    Args:
        spiking_data (ntime, nunits, ntr): Input spiking data

    Returns:
        tuple: A tuple containing
            | **unit_mean (nunits):** The mean spike count of each unit across trials
            | **unit_variance (nunits):** The spike count variance of each unit across trials
    '''
    # Sum over the time axis (axis 0) so counts has the shape (nunits, ntr)
    counts = np.sum(spiking_data, axis=0)
    unit_mean = np.mean(counts, axis=1)  # Average the counts for each unit across all trials
    unit_variance = np.var(counts, axis=1)  # Count variance for each unit across all trials
    return unit_mean, unit_variance