Source code for aopy.preproc.quality

# quality.py
# 
# Functions for assessing data quality

import copy
import traceback

import numpy as np
import numpy.linalg as npla
import scipy.signal as sps
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from .. import precondition
from .. import analysis
from .. import utils
from .. import visualization
from .. import data as aodata

def detect_bad_exp_data(preproc_dir, subjects, ids, dates):
    '''
    Identifies preprocessed experiment data files that contain errors. Detected errors include:
    preprocessed data could not be loaded, preprocessed data is missing event timestamps (likely
    the result of missing ecube data), and preprocessed data contains events that do not match
    bmi3d_events (likely the result of inaccurately sent digital events).

    The detected files should be re-preprocessed or otherwise debugged before being used in analysis!

    Args:
        preproc_dir (str): base directory where the files live
        subjects (list of str): Subject name for each recording
        ids (list of int): Block number of Task entry object for each recording
        dates (list of str): Date for each recording

    Returns:
        list of lists: entries (subject, date, id) of each preprocessed experiment data file
            that was identified to have errors
    '''
    bad_entries = []
    # Materialize the zip so that tqdm can report a total.
    entries = list(zip(subjects, dates, ids))

    for subject, date, te in tqdm(entries):

        # Load data from bmi3d hdf.
        # `except Exception` (rather than a bare `except`) keeps the best-effort
        # behavior for load errors while letting KeyboardInterrupt / SystemExit
        # propagate, so a long scan can still be cancelled by the user.
        try:
            exp_data, exp_metadata = aodata.load_preproc_exp_data(preproc_dir, subject, te, date)
        except Exception:
            print(f"Entry {subject} {date} {te} could not be loaded.")
            traceback.print_exc()
            bad_entries.append([subject, date, te])
            continue

        # Check events and times. Only the lookup matters, not the value.
        try:
            exp_data['events']['timestamp']
        except Exception:
            print(f"Entry {subject} {date} {te} is missing event timestamps (likely missing ecube data).")
            print('source files:', exp_metadata['source_files'])
            bad_entries.append([subject, date, te])
            continue

        # Check that events are accurate
        if not np.array_equal(exp_data['events']['code'], exp_data['bmi3d_events']['code']):
            print(f"Entry {subject} {date} {te} was excluded due to mismatched sync and bmi3d events (this will likely cause problems).")
            bad_entries.append([subject, date, te])

    return bad_entries
# python implementation of badChannelDetection.m - see which channels are too noisy
def bad_channel_detection(data, srate, num_th=3., lf_c=100., sg_win_t=8., sg_over_t=4., sg_bw=0.5):
    """bad_channel_detection

    Flags channels of an [nt, nch] data array based on the variance of their
    low-frequency multitaper power estimate.

    Args:
        data (nt, nch): numpy array of data
        srate (int): sample rate
        num_th (float, optional): Constant dividing the mean normalized variance to set the threshold. Defaults to 3.
        lf_c (int, optional): low frequency cutoff, passed to the spectrogram as ``fk``. Defaults to 100.
        sg_win_t (numeric, optional): spectrogram window length. Defaults to 8.
        sg_over_t (numeric, optional): spectrogram overlap length. Defaults to 4.
        sg_bw (float, optional): spectrogram time-half-bandwidth product. Defaults to 0.5.

    Returns:
        bad_ch (nch): logical array indicating bad channels
    """
    step_t = sg_win_t - sg_over_t
    assert step_t > 0, 'window length must be greater than window overlap'
    print("Running bad channel assessment:")

    # Low-frequency spectrogram (frequencies limited by fk=lf_c).
    n, p, k = precondition.convert_taper_parameters(sg_win_t, sg_bw)
    _, _, spec_low = analysis.calc_mt_tfr(data, n, p, k, srate, step=step_t, fk=lf_c)

    # Collapse axis 1 of the spectrogram (presumably the time-window axis --
    # TODO confirm against calc_mt_tfr), then take the variance along axis 0.
    psd_low = np.mean(spec_low, axis=1)
    chan_var = np.var(psd_low, axis=0)

    # Normalize by the vector norm and threshold at mean / num_th.
    chan_var_normed = chan_var / npla.norm(chan_var)
    threshold = np.mean(chan_var_normed) / num_th

    return chan_var_normed >= threshold
def detect_bad_ch_outliers(data, nbins=10000, thr=0.05, numsd=5.0, debug=False, verbose=True):
    '''
    Detect bad channels from the distribution of per-channel standard deviations.
    This code is originally Amy's code in pesaran lab.

    This function has 2 steps.
    In the 1st step, it detects tentative bad channels whose SD is less than thr or more than 1-thr in the CDF.
    In the 2nd step, it extracts bad channels from the tentative bad channels by finding those where
    abs(SD - median of SD) is outside numsd*std(SD[~tentative bad channels]).

    If no histogram bin has a CDF value below ``thr`` (for example when there are fewer than
    1/thr channels), no channel is rejected on the low side in the 1st step.

    Args:
        data (nt, nch): neural data
        nbins (int): number of bins to make CDF
        thr (float, optional): threshold in the CDF to detect bad channels for 1st screening.
            0.05 means data outside 5% or 95% in the CDF is regarded as bad channels.
        numsd (float, optional): number of standard deviations used to detect bad channels for 2nd screening
        debug (bool, optional): if True, display a figure showing the threshold crossings
        verbose (bool, optional): if True, print bad channels

    Returns:
        bad_ch (nch) : logical array indicating bad channels after 2nd screening

    Example:
        .. code-block:: python

            test_data = np.random.normal(10,0.5,(10000, 200))
            test_data[0, 10] = 25
            test_data[5, 150] = 30
            bad_ch = quality.detect_bad_ch_outliers(test_data, nbins=10000, thr=0.05, numsd=5.0, debug=True, verbose=False)

        .. image:: _images/detect_bad_ch_outliers.png
    '''
    assert thr > 0 and thr < 1, "Threshold must be between 0 and 1"
    assert numsd > 0, "numsd must be more than 0"

    sd = np.std(data, axis=0)
    nch = data.shape[1]
    med_sd = np.median(sd)

    # 1st screening: empirical CDF of the per-channel SD
    hist, bins = np.histogram(sd, nbins)
    CDF = np.cumsum(hist)/np.sum(hist)
    low_bins = np.flatnonzero(CDF < thr)
    high_bins = np.flatnonzero(CDF > 1 - thr)
    # The lowest bin can already hold more than `thr` of the channels, leaving
    # no bin below the threshold; fall back to the lowest edge (== min(sd)) so
    # that nothing is rejected on the low side instead of raising IndexError.
    bottom = bins[low_bins[-1]] if low_bins.size else bins[0]
    # CDF[-1] == 1 > 1 - thr, so high_bins is never empty.
    top = bins[high_bins[0]]
    bad_ch1 = (sd < bottom) | (sd > top)

    # 2nd screening: keep only tentative channels far from the median SD
    thr2 = numsd*np.std(sd[~bad_ch1], ddof=1)
    sd_from_med = np.abs(sd - med_sd)
    inrange = sd_from_med <= thr2
    bad_ch = bad_ch1 & ~inrange

    if debug:
        fig, ax = plt.subplots(1, 2, figsize=(11, 4), tight_layout=True)
        ax[0].plot(bins[:-1], CDF)
        ax[0].scatter(bins[low_bins], CDF[low_bins], c='g', s=1, label='bad_ch (1st screening)')
        ax[0].scatter(bins[high_bins], CDF[high_bins], c='g', s=1)
        ax[0].plot([top, top], [0, 1], 'r--', label='threshold (1st screening)')
        ax[0].plot([bottom, bottom], [0, 1], 'r--')
        ax[0].set(xlabel='SD', ylabel='CDF', title='1st screening')
        ax[0].legend()

        ax[1].plot(sd_from_med, '.')
        ax[1].plot([0, nch], [thr2, thr2], 'r--', label='threshold (2nd screening)')
        ax[1].plot(np.where(bad_ch1)[0], sd_from_med[bad_ch1], 'g.', label='bad_ch (1st screening)')
        ax[1].plot(np.where(bad_ch)[0], sd_from_med[bad_ch], 'r*', label='bad_ch (2nd screening)')
        ax[1].set(ylabel='abs(SD-median of SD)', xlabel='# channels', title='2nd screening')
        ax[1].legend()

    if verbose:
        print(f'Bad channels : {np.where(bad_ch)[0]}')
        print(f'The number of bad channels : {np.sum(bad_ch)}')

    return bad_ch
def detect_bad_trials(erp, sd_thr=5, ch_frac=0.5, debug=False):
    '''
    Flags trials in which more than a given fraction of channels contain at least
    one outlier sample. Outliers are samples further than ``sd_thr`` standard
    deviations from the channel median, where the median and standard deviation
    are computed per channel across all timepoints and trials (ignoring NaN).

    Args:
        erp (nt, nch, ntr): trial-aligned continuous data
        sd_thr (float, optional): number of standard deviations away from the channel median
            beyond which a sample counts as bad data. Default 5
        ch_frac (float, optional): fraction (between 0. and 1.) of channels containing bad data
            to consider a trial as bad. Default 0.5
        debug (bool, optional): if True, display a figure showing the threshold crossings

    Returns:
        (ntr,) boolean mask: True for bad trials, False for good trials

    Example:
        .. code-block:: python

            nt = 50
            nch = 10
            ntr = 100
            erp = np.random.normal(size=(nt, nch, ntr))
            erp[:,:,0] += 10 # entire trial is noisy across all electrodes
            erp[:,:8,1] -= 10 # entire trial is noisy on most electrodes
            erp[0,:,2] += 10 # single timepoint within the trial is noisy on all electrodes
            for t in range(nt):
                erp[t,t%nch,3] -= 10 # single timepoint is noisy but different timepoint for each channel
            bad_trials = detect_bad_trials(erp, sd_thr=5, ch_frac=0.5, debug=True)

        .. image:: _images/detect_bad_trials.png
    '''
    assert erp.ndim == 3
    nt, nch, ntr = erp.shape

    # Per-channel statistics over time and trials, kept broadcastable to erp.
    ch_median = np.nanmedian(erp, axis=(0, 2), keepdims=True)
    ch_sd = np.nanstd(erp, axis=(0, 2), keepdims=True)

    outlier_samples = np.abs(erp - ch_median) > sd_thr*ch_sd
    outlier_ch_trials = np.any(outlier_samples, axis=0)  # (nch, ntr): any bad timepoint
    n_outlier_ch = np.sum(outlier_ch_trials, axis=0)     # (ntr,): bad channels per trial
    bad_trials = n_outlier_ch > ch_frac * nch

    if debug:
        trial_idx = np.arange(ntr)

        # Left panel: peak deviation (in sd) per channel and trial, with the
        # channel/trial pairs over threshold blanked out.
        plt.figure(figsize=(11, 4), layout='compressed')
        plt.subplot(1, 2, 1)
        peak_dev = np.nanmax(np.abs(erp - ch_median)/ch_sd, axis=0)
        peak_dev[outlier_ch_trials] = np.nan
        cmap = copy.copy(plt.get_cmap('viridis'))
        cmap.set_bad(color='w')  # NaN entries are drawn in white
        im = visualization.plot_image_by_time(trial_idx, peak_dev.T, ylabel='channel', cmap=cmap)
        cbar = plt.colorbar(im)
        cbar.set_label('sd')
        plt.xlabel('trial')
        plt.title('sd over threshold shown in white')

        # Right panel: number of outlier channels for each trial.
        plt.subplot(1, 2, 2)
        good_trials = ~bad_trials
        plt.scatter(trial_idx[good_trials], n_outlier_ch[good_trials], marker='.', color='k', label='good trials')
        plt.scatter(trial_idx[bad_trials], n_outlier_ch[bad_trials], marker='x', color='r', label='bad trials')
        plt.xlabel('trial')
        plt.ylabel('# channels')
        plt.hlines(ch_frac*nch, 0, ntr, linestyles='dashed', color='r')
        plt.title('fraction of channels above threshold')
        plt.legend()

    return bad_trials
def detect_bad_timepoints(data, sd_thr=5, ch_frac=0.5, debug=False):
    '''
    Finds timepoints where more than a given fraction of channels contain outlier data.
    Outliers are samples further than ``sd_thr`` standard deviations from the channel median
    (both computed per channel over time, ignoring NaN). For best results, you may need to
    first compute the power of the data before using this function.

    Args:
        data (nt, nch): continuous data
        sd_thr (float, optional): number of standard deviations away from the channel median
            beyond which a sample counts as bad data. Default 5
        ch_frac (float, optional): fraction (between 0. and 1.) of channels containing bad data
            to consider a timepoint as bad. Default 0.5
        debug (bool, optional): if True, display a figure showing the threshold crossings

    Returns:
        (nt,) boolean mask: True for bad timepoints, False for good timepoints

    Example:
        .. code-block:: python

            nt = 200
            nch = 10
            np.random.seed(0)
            data = np.random.normal(size=(nt, nch))
            data[0:50,:] += 10 # timepoint is noisy across all electrodes
            data[50:100,8:] -= 10 # timepoint is noisy on most electrodes
            for t in range(100,nt):
                data[t,t%nch] -= 10 # single timepoint is noisy but different timepoint for each channel
            bad_timepoints = quality.detect_bad_timepoints(data, sd_thr=5, ch_frac=0.5, debug=True)

        .. image:: _images/detect_bad_timepoints.png
    '''
    assert data.ndim == 2
    nt, nch = data.shape

    median = np.nanmedian(data, axis=0, keepdims=True)
    sd = np.nanstd(data, axis=0, keepdims=True)
    bad_timepoints = abs(data - median) > sd_thr*sd  # (nt, nch) per-sample outlier mask
    n_bad_ch = np.sum(bad_timepoints, axis=1)        # (nt,) outlier channels per timepoint
    bad_ch_timepoints = n_bad_ch > ch_frac * nch     # timepoints where most channels have an outlier

    if debug:
        time = np.arange(nt)

        # Highlight bad samples across time. `deviation` is a new array, so the
        # caller's `data` is not modified by the NaN masking below.
        plt.figure(figsize=(11, 4), layout='compressed')
        plt.subplot(1, 2, 1)
        deviation = abs(data - median)/sd
        deviation[bad_timepoints] = np.nan
        cmap = copy.copy(plt.get_cmap('viridis'))
        cmap.set_bad(color='w')  # set the 'bad' color to white
        im = visualization.plot_image_by_time(time, deviation, ylabel='channel', cmap=cmap)
        cbar = plt.colorbar(im)
        cbar.set_label('sd')
        plt.xlabel('timepoint')
        plt.title('sd over threshold shown in white')

        # Plot number of bad channels for each timepoint
        plt.subplot(1, 2, 2)
        plt.scatter(time[~bad_ch_timepoints], n_bad_ch[~bad_ch_timepoints], marker='.', color='k', label='good timepoints')
        plt.scatter(time[bad_ch_timepoints], n_bad_ch[bad_ch_timepoints], marker='x', color='r', label='bad timepoints')
        plt.xlabel('timepoint')
        plt.ylabel('# channels')
        plt.hlines(ch_frac*nch, 0, nt, linestyles='dashed', color='r')
        plt.title('fraction of channels above threshold')
        plt.legend()

    return bad_ch_timepoints
# python implementation of highFreqTimeDetection.m - looks for spectral signatures of junk data
def high_freq_data_detection(data, srate, bad_channels=None, lf_c=100., sg_win_t=8., sg_over_t=4., sg_bw=0.5):
    """high_freq_data_detection

    Checks multichannel numpy array data for excess high frequency power.
    Returns a logical array of time locations in which any channel has excess high power (indicates noise).

    For each channel that is not marked bad, a multitaper spectrogram is computed and each window's
    mean power below and above ``lf_c`` is compared against thresholds derived from the window-averaged
    spectrum (low-band mean - 3 std, high-band mean + 3 std). Samples lying within half a window
    length of an offending window center are flagged.

    Args:
        data (nt, nch): timeseries data across channels
        srate (numeric): data sampling rate
        bad_channels (boolean array, optional): Array-like of boolean values indicating bad channels,
            which are skipped. None or an empty sequence means no channel is skipped. Defaults to None.
        lf_c (numeric, optional): low frequency cutoff. Defaults to 100.
        sg_win_t (numeric, optional): spectrogram window length. Defaults to 8.
        sg_over_t (numeric, optional): spectrogram overlap length. Defaults to 4.
        sg_bw (float, optional): spectrogram time-half-bandwidth product. Defaults to 0.5.

    Returns:
        bad_data_mask (nt): boolean array indicating timepoints with detected high-frequency noise on any channel
        bad_data_mask_all_ch (nt, nch): boolean array indicating, for each channel, the timepoints with
            detected high-frequency noise
    """
    print("Running high frequency noise detection: lfc @ {0}".format(lf_c))
    num_samp, num_ch = np.shape(data)
    bad_data_mask_all_ch = np.zeros((num_samp, num_ch), dtype=bool)
    data_t = np.arange(num_samp)/srate

    # `not bad_channels` is ambiguous (raises ValueError) for multi-element
    # arrays, so test for the "nothing supplied" cases explicitly.
    if bad_channels is None or np.size(bad_channels) == 0:
        bad_channels = np.zeros(num_ch, dtype=bool)
    good_ch_idx = np.flatnonzero(np.logical_not(bad_channels))

    # Spectrogram parameters do not depend on the channel: compute them once.
    sg_step_t = sg_win_t - sg_over_t
    assert sg_step_t > 0, 'window length must be greater than window overlap'
    n, p, k = precondition.convert_taper_parameters(sg_win_t, sg_bw)

    # estimate hf influence, channel-wise
    for ch_i in good_ch_idx:
        utils.print_progress_bar(ch_i, num_ch)
        fxx, txx, Sxx = analysis.calc_mt_tfr(data[:, ch_i], n, p, k, srate, step=sg_step_t)

        # Average over axis 1 (presumably the time-window axis -- TODO confirm
        # against calc_mt_tfr) to get one spectrum for the channel.
        Sxx_mean = np.mean(Sxx, axis=1)

        # get low-freq, high-freq data
        low_f_mask = fxx < lf_c  # Hz
        high_f_mask = np.logical_not(low_f_mask)
        low_f_mean = np.mean(Sxx_mean[low_f_mask], axis=0)
        low_f_std = np.std(Sxx_mean[low_f_mask], axis=0)
        high_f_mean = np.mean(Sxx_mean[high_f_mask], axis=0)
        high_f_std = np.std(Sxx_mean[high_f_mask], axis=0)

        # set thresholds for high, low freq. data
        low_thresh = low_f_mean - 3*low_f_std
        high_thresh = high_f_mean + 3*high_f_std

        for t_i, t_center in enumerate(txx):
            win_low_mean = np.mean(Sxx[low_f_mask, t_i])
            win_high_mean = np.mean(Sxx[high_f_mask, t_i])
            if win_low_mean < low_thresh or win_high_mean > high_thresh:
                # flag every sample covered by this spectrogram window
                t_bad_mask = np.logical_and(data_t > t_center - sg_win_t/2,
                                            data_t < t_center + sg_win_t/2)
                bad_data_mask_all_ch[t_bad_mask, ch_i] = True

    bad_data_mask = np.any(bad_data_mask_all_ch, axis=1)

    return bad_data_mask, bad_data_mask_all_ch
# py version of noiseByHistogram.m - get upper and lower signal value bounds from a histogram
def histogram_defined_noise_levels(data, nbin=20):
    """histogram_defined_noise_levels

    Determine lower and upper signal value bounds from a histogram of the data.

    The inner histogram edges (upper edge of the first bin, lower edge of the last bin) are
    compared against mean -/+ 3 standard deviations of the data. An inner edge is used as the
    bound only when it lies outside that interval; otherwise the data minimum / maximum is used.

    Args:
        data (np.array): single-channel data array
        nbin (int, optional): number of histogram bins. Defaults to 20.

    Returns:
        noise_bounds (tuple): lower, upper bound values
    """
    # Inner edges of the outermost histogram bins.
    _, bin_edge = np.histogram(data, bins=nbin)
    low_edge, high_edge = bin_edge[1], bin_edge[-2]

    # Gaussian-style interval: mean -/+ 3 standard deviations.
    # NOTE(review): the original comment described this as computed "from
    # trimmed data" (samples strictly between the inner edges), but the
    # statistics have always been taken over the full data. The numerical
    # behavior is kept as-is; confirm the intent against noiseByHistogram.m.
    data_mean = np.mean(data)
    data_std = np.std(data)
    data_CI_lower = data_mean - 3*data_std
    data_CI_higher = data_mean + 3*data_std

    # np.min / np.max give the same values as the builtins for a 1-D array
    # without iterating the array element by element in Python.
    noise_lower = low_edge if low_edge < data_CI_lower else np.min(data)
    noise_upper = high_edge if high_edge > data_CI_higher else np.max(data)

    return (noise_lower, noise_upper)
# py version of saturatedTimeDetection.m - get indices of saturated data segments
def saturated_data_detection(data, srate, bad_channels=None, adapt_tol=1e-8, win_n=20, verbose=True):
    """saturated_data_detection

    Detects saturated data segments in input data array.

    For each channel that is not marked bad, the rectified signal is histogrammed and a threshold
    is found iteratively as the midpoint between the mean of the values at or below the threshold
    and the mean of the values above it. A sample is flagged when the boxcar-filtered rectified
    signal exceeds that threshold and the rectified sample lies outside the bounds returned by
    histogram_defined_noise_levels. If all values fall on one side of the threshold, nothing is
    flagged for that channel.

    Args:
        data (nt, nch): numpy array of multichannel data
        srate (numeric): data sampling rate (currently unused)
        bad_channels (bool array, optional): boolean array indicating bad data channels, which are
            skipped. None or an empty sequence means no channel is skipped. Default: None
        adapt_tol (float, optional): detection tolerance (convergence tolerance of the threshold
            iteration). Default: 1e-8
        win_n (int, optional): sample length of detection window. Default: 20
        verbose (bool, optional): if True, print progress bar. Default: True

    Returns:
        sat_data_mask (nt): boolean array, True where more than half of all channels are flagged
        bad_all_ch_mask (nt, nch): boolean array indicating saturation detected on each channel
    """
    if verbose:
        print("Running saturated data segment detection:")
    num_samp, num_ch = np.shape(data)

    # `not bad_channels` is ambiguous (raises ValueError) for multi-element
    # arrays, so test for the "nothing supplied" cases explicitly.
    if bad_channels is None or np.size(bad_channels) == 0:
        bad_channels = np.zeros(num_ch, dtype=bool)
    good_ch_idx = np.flatnonzero(np.logical_not(bad_channels))

    bad_all_ch_mask = np.zeros((num_samp, num_ch), dtype=bool)
    data_rect = np.abs(np.float32(data))

    # boxcar (moving-average) filter, identical for every channel
    b_filt = np.ones(win_n)/win_n
    a_filt = 1

    for ch_i in good_ch_idx:
        if verbose:
            utils.print_progress_bar(ch_i, num_ch)
        ch_data = data_rect[:, ch_i]

        # np.histogram needs at least one bin; int(max) is 0 when the
        # rectified signal never reaches 1.
        n_hist_bins = max(int(np.max(ch_data)), 1)
        h, valc = np.histogram(ch_data, n_hist_bins)
        val = np.floor((valc[1:] + valc[:-1])/2)  # floored bin midpoints (valc are the edges)
        # Only ratios of these weights are used, so the normalization constant
        # does not affect the result.
        prob_val = h/np.shape(h)[0]

        # Estimate the midpoint between the two modes of a bimodal
        # distribution for a threshold value.
        thresh = 50  # initial threshold value
        prev_thresh = 0
        while np.abs(thresh - prev_thresh) > adapt_tol:
            prev_thresh = thresh
            sub_mask = val <= thresh
            sup_mask = np.logical_not(sub_mask)
            sub_weight = np.sum(prob_val[sub_mask])
            sup_weight = np.sum(prob_val[sup_mask])
            if sub_weight == 0 or sup_weight == 0:
                # Only one class is populated: there is no second mode, so use
                # a threshold that no sample can exceed (nothing is flagged).
                thresh = np.inf
                break
            sub_mean = np.sum(val[sub_mask]*prob_val[sub_mask])/sub_weight
            # Mean of the values ABOVE the threshold (numerator and denominator
            # must both use the above-threshold bins).
            sup_mean = np.sum(val[sup_mask]*prob_val[sup_mask])/sup_weight
            thresh = (sub_mean + sup_mean)/2

        # filter signal, boxcar window
        ch_data_filt = sps.lfilter(b_filt, a_filt, ch_data)
        filt_above_thresh = ch_data_filt > thresh

        # get histogram-derived noise limits
        n_low, n_high = histogram_defined_noise_levels(ch_data)
        out_of_range = np.logical_or(ch_data < n_low, ch_data > n_high)
        bad_all_ch_mask[:, ch_i] = np.logical_and(filt_above_thresh, out_of_range)

        # TODO: clearing out straggler values is not implemented yet.

    num_bad = np.sum(bad_all_ch_mask, axis=1)
    sat_data_mask = num_bad > num_ch/2

    return sat_data_mask, bad_all_ch_mask