Source code for aopy.visualization.base

# visualization.py
# 
# Code for general neural data plotting (raster plots, multi-channel field potential plots, psth, etc.)

import string
import warnings
from datetime import timedelta
import os
import copy
import sys

from matplotlib.markers import MarkerStyle
if sys.version_info >= (3,9):
    from importlib.resources import files, as_file
else:
    from importlib_resources import files, as_file

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib import colors
from matplotlib import cm
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import seaborn as sns
from scipy.interpolate import griddata
from scipy.spatial import cKDTree
from scipy import signal
from scipy.stats import zscore
import numpy as np
from PIL import Image
import pandas as pd
from tqdm import tqdm

from .. import precondition
from .. import analysis
from .. import data as aodata
from .. import utils
from .. import preproc

def plot_mean_fr_per_target_direction(means_d, neuron_id, ax, color, this_alpha, this_label):
    '''
    Plot one neuron's mean firing rate across target directions as a single line.

    Side effects: sets the global seaborn plotting context to 'talk' and calls
    plt.tight_layout() on the current figure.

    Args:
        means_d (array-like): mean firing rates; column ``neuron_id`` is plotted
            (presumably shaped (ntarget, nneuron) -- confirm against callers)
        neuron_id (int): column of means_d to plot
        ax (pyplot.Axes): axis to draw on
        color: line color passed to ax.plot
        this_alpha (float): line transparency
        this_label (str): legend label for the line
    '''
    sns.set_context('talk')
    neuron_rates = np.array(means_d)[:, neuron_id]
    ax.plot(neuron_rates, c=color, alpha=this_alpha, label=this_label)
    ax.legend()
    label_size = 16
    ax.set_xlabel("Target", fontsize=label_size)
    ax.set_ylabel("Spike Rate (Hz)", fontsize=label_size)
    plt.tight_layout()
def savefig(base_dir, filename, **kwargs):
    '''
    Save the current matplotlib figure with lab-default options.

    A '.png' extension is appended when ``filename`` contains no '.' at all.
    Unless overridden by the caller, dpi=300, edgecolor='none' and transparent=True
    are used. When the figure is transparent, the facecolor also defaults to 'none'.

    Args:
        base_dir (str): directory in which to write the figure
        filename (str): name of the figure file
        **kwargs (optional): arguments forwarded to plt.savefig()
    '''
    if '.' not in filename:
        filename = filename + '.png'

    kwargs.setdefault('dpi', 300.)
    kwargs.setdefault('edgecolor', 'none')
    kwargs.setdefault('transparent', True)
    if kwargs['transparent']:
        kwargs.setdefault('facecolor', 'none')

    plt.savefig(os.path.join(base_dir, filename), **kwargs)
def subplots_with_labels(n_rows, n_cols, return_labeled_axes=False, rel_label_x=-0.25, rel_label_y=1.1,
                         label_font_size=11, constrained_layout=False, **kwargs):
    '''
    Create a grid of subplots, each tagged with a capital letter (A, B, C, ...).
    Built on top of plt.subplot_mosaic(); at most 26 panels are supported.

    Examples:
        Generate a figure with 2 rows and 2 columns of subplots, labeled A, B, C, D

        .. code-block:: python

            fig, axes = subplots_with_labels(2, 2, constrained_layout=True)

        .. image:: _images/labeled_subplots.png

    Args:
        n_rows (int): number of subplot rows.
        n_cols (int): number of subplot columns.
        return_labeled_axes (bool, optional): also return the letter -> Axes dictionary. Default False.
        rel_label_x (float, optional): x position of each label in axes coordinates. Default -0.25.
        rel_label_y (float, optional): y position of each label in axes coordinates. Default 1.1.
        label_font_size (int, optional): font size of the labels. Default 11.
        constrained_layout (bool, optional): passed to plt.subplot_mosaic. Default False.
        **kwargs: additional keyword arguments passed to plt.subplot_mosaic.

    Raises:
        ValueError: if more than 26 subplots are requested.

    Returns:
        tuple: (fig, axes), or (fig, axes, labels_axes) if return_labeled_axes is True. Here axes is an
        (n_rows, n_cols) array of Axes and labels_axes maps each letter to its Axes.
    '''
    n_panels = n_rows * n_cols
    if n_panels > 26:
        raise ValueError("More than 26 subplots requested, running out of single letters to label them with!")

    # One string of letters per row; rows are joined by ';' (subplot_mosaic's compact layout syntax)
    letters = string.ascii_uppercase[:n_panels]
    rows = [letters[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]
    fig, labels_axes = plt.subplot_mosaic(";".join(rows), constrained_layout=constrained_layout, **kwargs)

    for letter, panel in labels_axes.items():
        panel.text(rel_label_x, rel_label_y, letter, transform=panel.transAxes, size=label_font_size)

    axes = np.array(list(labels_axes.values())).reshape((n_rows, n_cols))
    if return_labeled_axes:
        return fig, axes, labels_axes
    return fig, axes
def place_subplots(fig, positions, width, height, **kwargs):
    '''
    Add equally sized subplots to a figure, centered at arbitrary positions given in inches
    from the bottom-left corner of the figure.

    Args:
        fig (pyplot.Figure): figure to place the subplots on
        positions (npos, 2): (x, y) centers of the subplots, in inches
        width (float): width (in inches) of each subplot
        height (float): height (in inches) of each subplot
        kwargs (dict, optional): other keyword arguments to pass to fig.add_axes

    Returns:
        list: pyplot.Axes handles, one per position

    Examples:

        .. code-block:: python

            fig = plt.figure(figsize=(4,6))
            positions = [[1, 2], [3, 4]]
            width = 1
            height = 1
            ax = place_subplots(fig, positions, width, height)
            ax[0].annotate('1', (0.5,0.5), fontsize=40)
            ax[1].annotate('2', (0.5,0.5), fontsize=40)

        .. image:: _images/place_subplots_1.png

        .. code-block:: python

            fig = plt.figure(figsize=(4,6))
            positions = [[1, 1.5], [3, 4.5]]
            width = 2
            height = 3
            ax = place_subplots(fig, positions, width, height)
            ax[0].annotate('1', (0.5,0.5), fontsize=40)
            ax[1].annotate('2', (0.5,0.5), fontsize=40)

        .. image:: _images/place_subplots_2.png
    '''
    fig_width, fig_height = fig.get_size_inches()

    # add_axes works in figure-fraction units, so convert inches to fractions
    centers = np.array(positions, dtype='float') / np.array([fig_width, fig_height])
    rel_width = width / fig_width
    rel_height = height / fig_height

    return [
        fig.add_axes([cx - rel_width/2, cy - rel_height/2, rel_width, rel_height], **kwargs)
        for cx, cy in centers
    ]
def place_Opto32_subplots(fig_size=5, subplot_size=0.75, offset=(0.,-0.25), theta=0, **kwargs):
    '''
    Create a square figure with one subplot per Opto32 stimulation site, laid out according to the
    Opto32 channel map. Wraps place_subplots().

    Args:
        fig_size (float): width and height (in inches) of the figure
        subplot_size (float): width and height (in inches) of each subplot
        offset (tuple): x and y shift (in inches) applied to every subplot center
        theta (float): rotation (in degrees) to apply to positions.
        kwargs (dict, optional): other keyword arguments passed to plt.figure()

    Returns:
        tuple: tuple containing:
            | **fig (pyplot.Figure):** figure where the subplots were placed
            | **ax (list):** pyplot.Axes handles for each stimulation site

    Examples:

        .. image:: _images/place_Opto32_subplots.png
    '''
    stim_pos, _, _ = aodata.load_chmap('Opto32', theta=theta)

    # Center the sites on the figure; the largest coordinate span (over both axes) maps to fig_size
    centered = stim_pos - np.mean(stim_pos, axis=0)
    span = np.max(stim_pos) - np.min(stim_pos)
    stim_pos = centered / span * fig_size + fig_size/2

    fig = plt.figure(figsize=(fig_size,fig_size), **kwargs)
    ax = place_subplots(fig, stim_pos + np.array(offset), subplot_size, subplot_size)

    # Hide ticks and tick labels on every site axis
    no_ticks = dict(which='both', bottom=False, left=False, labelbottom=False, labelleft=False)
    for site_ax in ax:
        site_ax.tick_params(**no_ticks)
    return fig, ax
def plot_colorbar(size, cmap, clim=(0,1), orientation='vertical', ticks=None, label=None, labelpad=5, **kwargs):
    '''
    Draw a standalone colorbar, in a new figure, for the given colormap and color limits.

    Args:
        size (2-tuple): (width, height) of the colorbar in inches
        cmap (str or Colormap): colormap to use
        clim (2-tuple, optional): (min, max) color limits. Default (0,1)
        orientation (str, optional): 'vertical' or 'horizontal'. Any value other than 'horizontal'
            uses the vertical figure layout. Default 'vertical'
        ticks (list, optional): tick locations passed to plt.colorbar(). Default None (automatic)
        label (str, optional): colorbar label. Default None (no label)
        labelpad (float, optional): padding between the colorbar and its label. Default 5
        kwargs (dict, optional): additional keyword arguments to pass to plt.colorbar()

    Returns:
        Colorbar: the created colorbar object

    Examples:

        .. code-block:: python

            aopy.visualization.plot_colorbar(
                size=(0.2,2), cmap='viridis', clim=(0, 100), ticks=[0,100],
                orientation='vertical', label='(0.2 x 2) in colorbar', labelpad=-15,
            )

        .. image:: _images/colorbar_example_vertical.png

        .. code-block:: python

            aopy.visualization.plot_colorbar(
                size=(3,0.3), cmap='hsv', clim=(0, 1), orientation='horizontal',
                label='(3 x 0.3) in colorbar'
            )

        .. image:: _images/colorbar_example_horizontal.png
    '''
    # The figure is larger than the requested bar; margins shrink the axes down to `size`
    if orientation == 'horizontal':
        figsize = (1.1*size[0], 5*size[1])
        margins = dict(left=0.05, right=0.95, top=0.95, bottom=0.75)
    else:
        figsize = (5*size[0], 1.1*size[1])
        margins = dict(bottom=0.05, top=0.95, left=0.05, right=0.25)
    fig, cax = plt.subplots(figsize=figsize)
    plt.subplots_adjust(**margins)

    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=colors.Normalize(vmin=clim[0], vmax=clim[1]))
    mappable.set_array([]) # only needed for older matplotlib versions to avoid a warning
    cbar = plt.colorbar(mappable, cax=cax, orientation=orientation, ticks=ticks, **kwargs)
    if label is not None:
        cbar.set_label(label, labelpad=labelpad)
    return cbar
def plot_timeseries(data, samplerate, t0=0., ax=None, **kwargs):
    '''
    Plot each channel of a timeseries against time on the given axis.
    Axis labels assume units of seconds and volts.

    Example:
        Plot 50 and 100 Hz sine wave.

        .. code-block:: python

            data = np.sin(np.pi*np.arange(1000)/10) + np.sin(2*np.pi*np.arange(1000)/10)
            samplerate = 1000
            plot_timeseries(data, samplerate)

        .. image:: _images/timeseries.png

    Args:
        data (nt, nch): timeseries data in volts, can also be a single channel vector
        samplerate (float): sampling rate of the data
        t0 (float, optional): time (in seconds) of the first sample. Default 0.
        ax (pyplot axis, optional): where to plot. Default current axis.
        kwargs (dict, optional): optional keyword arguments to pass to plt.plot
    '''
    # Promote a single channel vector to a (nt, 1) column
    if np.ndim(data) < 2:
        data = np.expand_dims(data, 1)
    ax = plt.gca() if ax is None else ax

    n_samples, n_channels = np.shape(data)[0], np.shape(data)[1]
    t = np.arange(n_samples) / samplerate + t0
    for channel in range(n_channels):
        ax.plot(t, data[:, channel], **kwargs)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Voltage (V)')
def gradient_timeseries(data, samplerate, n_colors=100, color_palette='viridis', ax=None, **kwargs):
    '''
    Draw timeseries data as lines whose color changes gradually over time.
    Axis labels assume units of seconds and volts.

    Args:
        data (nt, nch): timeseries to plot, can be 1d or 2d.
        samplerate (float): sampling rate of the data
        n_colors (int, optional): number of colors in the gradient. Default 100. Clamped to the number
            of samples (and to at least 1), so very short signals still plot.
        color_palette (str, optional): colormap to use for the gradient. Default 'viridis'.
        ax (plt.Axis, optional): axis to plot on. Default current axis.
        kwargs (dict): keyword arguments to pass to the LineCollection function (similar to plt.plot)

    Raises:
        ValueError: if the data has more than 2 dimensions

    Example:

        .. code-block:: python

            data = np.sin(np.pi*np.arange(1000)/100)
            samplerate = 1000
            gradient_timeseries(data, samplerate)

        .. image:: _images/timeseries_gradient.png
    '''
    data = np.asarray(data)
    if data.ndim == 1:
        data = np.expand_dims(data, 1)
    elif data.ndim > 2:
        raise ValueError('Data with more than 2 dimensions not supported!')
    if ax is None:
        ax = plt.gca()

    n_pt = data.shape[0]
    time = np.arange(n_pt) / samplerate

    # The palette and the segment labels must agree on the number of colors. Previously the palette
    # was shortened to n_pt entries while labels still went up to n_colors-1, raising IndexError.
    n_colors = max(1, min(n_colors, n_pt))
    palette = sns.color_palette(color_palette, n_colors)

    # Split the samples into n_colors equal runs; leftover samples get the last color
    seg_len = max(1, n_pt // n_colors)
    labels = np.minimum(np.arange(n_pt) // seg_len, n_colors - 1)

    times, _ = utils.segment_array(time, labels, duplicate_endpoints=True)
    lines, line_labels = utils.segment_array(data, labels, duplicate_endpoints=True)
    segment_colors = [palette[i] for i in np.array(line_labels).astype(int)]

    # One LineCollection per channel
    for dim in range(data.shape[1]):
        segments = [np.column_stack((t, seg[:, dim])) for t, seg in zip(times, lines)]
        ax.add_collection(LineCollection(segments, colors=segment_colors, **kwargs))

    ax.margins(0.05) # add_collections doesn't autoscale
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Voltage (V)')
def plot_freq_domain_amplitude(data, samplerate, ax=None, rms=False):
    '''
    Compute and plot the amplitude spectrum of each channel on a log-frequency axis.
    The spectrum is computed by analysis.calc_freq_domain_amplitude().

    Example:
        Plot 50 and 100 Hz sine wave amplitude spectrum.

        .. code-block:: python

            data = np.sin(np.pi*np.arange(1000)/10) + np.sin(2*np.pi*np.arange(1000)/10)
            samplerate = 1000
            plot_freq_domain_amplitude(data, samplerate) # Expect 100 and 50 Hz peaks at 1 V each

        .. image:: _images/freqdomain.png

    Args:
        data (nt, nch): timeseries data in volts, can also be a single channel vector
        samplerate (float): sampling rate of the data
        ax (pyplot axis, optional): where to plot. Default current axis.
        rms (bool, optional): compute root-mean square amplitude instead of peak amplitude
    '''
    ax = plt.gca() if ax is None else ax
    freqs, amplitudes = analysis.calc_freq_domain_amplitude(data, samplerate, rms)

    n_channels = np.shape(amplitudes)[1]
    for channel in range(n_channels):
        ax.semilogx(freqs, amplitudes[:, channel])

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('RMS amplitude (V)' if rms else 'Peak amplitude (V)')
def get_data_map(data, x_pos, y_pos):
    '''
    Arrange scattered values onto the grid defined by their unique x and y positions.

    Positions are rounded to 9 decimals before matching to absorb floating point error. Grid cells
    with no data are NaN. For that reason, integer or boolean data is returned as float64; float and
    complex data keep their dtype. If a position appears more than once, the later value wins.

    Args:
        data (nch): list of values
        x_pos (nch): list of x positions
        y_pos (nch): list of y positions

    Returns:
        (m,n array): map of the data, with rows ordered by increasing unique y position and columns
        by increasing unique x position
    '''
    data = np.reshape(data, -1)
    x_pos = np.round(x_pos, 9) # avoid floating point errors
    y_pos = np.round(y_pos, 9)
    X = np.unique(x_pos)
    Y = np.unique(y_pos)

    # NaN marks empty cells, which needs an inexact dtype (int arrays cannot hold NaN)
    dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    data_map = np.full((len(Y), len(X)), np.nan, dtype=dtype)

    # O(1) lookup of each position's column/row instead of scanning X and Y per sample
    x_index = {x: i for i, x in enumerate(X)}
    y_index = {y: i for i, y in enumerate(Y)}
    for idx, value in enumerate(data):
        col = x_index.get(x_pos[idx])
        row = y_index.get(y_pos[idx])
        if col is None or row is None:
            continue # e.g. NaN positions never match a grid coordinate
        data_map[row, col] = value
    return data_map
def calc_data_map(data, x_pos, y_pos, grid_size, interp_method='nearest', threshold_dist=None, extent=None):
    '''
    Turns scatter data into grid data by interpolating up to a given threshold distance.

    Args:
        data (nch): list of values
        x_pos (nch): list of x positions
        y_pos (nch): list of y positions
        grid_size (2-tuple): number of points along each axis (width, height), i.e. (n_x, n_y)
        interp_method (str): method used for interpolation (passed to scipy.interpolate.griddata)
        threshold_dist (float): grid points farther than this from every valid data point are set to NaN
        extent (list): [xmin, xmax, ymin, ymax] to define the extent of the interpolated grid. Default None,
            which will use the min and max of the x and y positions.

    Returns:
        tuple: tuple containing:
            | *data_map ((n_y, n_x) array, e.g. (16,16)):* map of the data on the given grid, rows along y
            | *xy (tuple of two (n_y*n_x,) arrays):* x and y coordinates of the grid points, in map order

    Raises:
        ValueError: if data, x_pos and y_pos do not all have the same length

    Example:
        Make a plot of a 10 x 10 grid of increasing values with some missing data.

        .. code-block:: python

            data = np.linspace(-1, 1, 100)
            x_pos, y_pos = np.meshgrid(np.arange(0.5,10.5),np.arange(0.5, 10.5))
            missing = [0, 5, 25]
            data_missing = np.delete(data, missing)
            x_missing = np.reshape(np.delete(x_pos, missing),-1)
            y_missing = np.reshape(np.delete(y_pos, missing),-1)

            data_map = get_data_map(data_missing, x_missing, y_missing)
            plt.figure()
            plot_spatial_map(data_map, x_missing, y_missing)

        .. image:: _images/posmap.png

        Use `calc_data_map` to interpolate the missing data

        .. code-block:: python

            interp_map, xy = calc_data_map(data_missing, x_missing, y_missing, [10, 10], threshold_dist=1.5)
            plot_spatial_map(interp_map, xy[0], xy[1])

        .. image:: _images/posmap_calcmap.png

        Use cubic interpolation to generate a high resolution map

        .. code-block:: python

            interp_map, xy = calc_data_map(data_missing, x_missing, y_missing, [100, 100],
                                           threshold_dist=1.5, interp_method='cubic')
            plt.figure()
            plot_spatial_map(interp_map, xy[0], xy[1])

        .. image:: _images/posmap_calcmap_interp.png
    '''
    if extent is None:
        extent = [np.min(x_pos), np.max(x_pos), np.min(y_pos), np.max(y_pos)]
    if len(x_pos) != len(y_pos):
        raise ValueError('x_pos and y_pos must have the same length!')
    if len(data) != len(x_pos):
        raise ValueError('Data and position must have the same length!')

    data = np.squeeze(data)
    # NaN is used for missing/masked points below, so the values must be inexact
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(float)
    xy = np.column_stack((x_pos, y_pos))

    # linspace (unlike a float arange) always yields exactly the requested number of points
    n_x, n_y = int(grid_size[0]), int(grid_size[1])
    xq, yq = np.meshgrid(np.linspace(extent[0], extent[1], n_x),
                         np.linspace(extent[2], extent[3], n_y)) # both shaped (n_y, n_x)

    # Remove nan values
    valid = ~np.isnan(data)
    data = data[valid]
    xy = xy[valid]

    # Interpolate
    new_xy = (xq.ravel(), yq.ravel())
    X = griddata(xy, data, new_xy, method=interp_method, rescale=False)

    # Mask grid points whose nearest data point is beyond the threshold (kd-tree as in scipy.interpolate)
    if threshold_dist:
        dists, _ = cKDTree(xy).query(np.column_stack(new_xy))
        X[dists > threshold_dist] = np.nan

    # meshgrid orders points row-major over (y, x), so the map is (n_y, n_x)
    data_map = np.reshape(X, (n_y, n_x))
    return data_map, new_xy
def plot_spatial_map(data_map, x, y, alpha_map=None, ax=None, cmap='bwr', nan_color='black', clim=None):
    '''
    Wrapper around plt.imshow for spatial data

    Args:
        data_map ((n,m) array): map of x,y data where n is the number of y positions and m is the number
            of x positions. Can also supply a (n,m,3) rgb or (n,m,4) rgba image to add color or transparency.
        x ((m,) array): list of x positions
        y ((n,) array): list of y positions
        alpha_map ((n,m) array): map of alpha values (optional, default alpha=1 everywhere). If the alpha
            values are outside of the range (0,1) they will be scaled automatically. NaN alpha values are
            treated as fully transparent. The caller's array is not modified.
        ax (pyplot.Axes, optional): axis on which to plot, default gca
        cmap (str, optional): matplotlib colormap to use in image. Default 'bwr'. This parameter is ignored
            if data_map is an rgb or rgba image.
        nan_color (str, optional): color to plot nan values, or None to leave them invisible. default 'black'
        clim ((2,) tuple): (min, max) to set the c axis limits. default None, show the whole range.
            Applies to scalar maps with or without an alpha_map; ignored for rgb/rgba input.

    Returns:
        mappable: image object which you can use to add colorbar, etc.

    Examples:
        Make a plot of a 10 x 10 grid of increasing values with some missing data.

        .. code-block:: python

            data = np.linspace(-1, 1, 100)
            x_pos, y_pos = np.meshgrid(np.arange(0.5,10.5),np.arange(0.5, 10.5))
            missing = [0, 5, 25]
            data_missing = np.delete(data, missing)
            x_missing = np.reshape(np.delete(x_pos, missing),-1)
            y_missing = np.reshape(np.delete(y_pos, missing),-1)

            data_map = get_data_map(data_missing, x_missing, y_missing)
            plot_spatial_map(data_map, x_missing, y_missing)

        .. image:: _images/posmap.png

        Make the same image but include a transparency layer

        .. code-block:: python

            plot_spatial_map(data_map, x_missing, y_missing, alpha_map=data_map)

        .. image:: _images/posmap_alphamap.png
    '''
    # Calculate the proper extents
    assert np.ndim(data_map) == 2 or np.ndim(data_map) == 3, 'data_map must be 2D or 3D'
    if np.size(data_map) > 1:
        extent = [np.min(x), np.max(x), np.min(y), np.max(y)]
        n_rows, n_cols = np.shape(data_map)[:2]
        # A single row/column has no spacing to measure; use unit spacing like the 1-pixel case
        x_spacing = (extent[1] - extent[0]) / (n_cols - 1) if n_cols > 1 else 1.0
        y_spacing = (extent[3] - extent[2]) / (n_rows - 1) if n_rows > 1 else 1.0
        # Pad by half a pixel so pixel centers land on the given positions
        extent = np.add(extent, [-x_spacing / 2, x_spacing / 2, -y_spacing / 2, y_spacing / 2])
    else:
        extent = [np.min(x) - 0.5, np.max(x) + 0.5, np.min(y) - 0.5, np.max(y) + 0.5]

    # Set the 'bad' color to something different (copy so the registered colormap is untouched)
    cmap = copy.copy(plt.get_cmap(cmap))
    if nan_color:
        cmap.set_bad(color=nan_color)

    # If an alpha map is present, make an rgba image
    if alpha_map is not None:
        if clim is None:
            clim = (np.nanmin(data_map), np.nanmax(data_map))
        norm = cm.colors.Normalize(*clim)
        data_map = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(data_map)

        # Work on a float copy so the caller's alpha_map is never modified in place
        alpha_map = np.array(alpha_map, dtype=float)
        alpha_min, alpha_max = np.nanmin(alpha_map), np.nanmax(alpha_map)
        alpha_range = alpha_max - alpha_min
        if alpha_range > 1 or alpha_max > 1 or alpha_min < 0:
            alpha_map = (alpha_map - alpha_min) / alpha_range
        alpha_map[np.isnan(alpha_map)] = 0
        data_map[:,:,3] = alpha_map

    # Plot
    if ax is None:
        ax = plt.gca()
    image = ax.imshow(data_map, cmap=cmap, origin='lower', extent=extent)
    # Color limits only apply to scalar images; previously clim was ignored without an alpha_map
    if clim is not None and np.ndim(data_map) == 2:
        image.set_clim(clim)
    ax.set_xlabel('x position')
    ax.set_ylabel('y position')
    return image
def plot_spatial_drive_map(data, elec_data=False, drive_type='ECoG244', interp=True, grid_size=(16,16), cmap='bwr',
                           theta=0, ax=None, **kwargs):
    '''
    Plot a 2D spatial map of data from a spatial electrode array.

    Args:
        data ((nch,) array): values from the spatial drive to plot in 2D
        elec_data (bool, optional): if True, treat data as electrode data (i.e. nch == nelec), otherwise
            treat it as acquisition data (nch >= nelec) and pick out the mapped acquisition channels.
            Defaults to False.
        drive_type (str, optional): type of drive. See :func:`~aopy.data.load_chmap` for options.
            Defaults to 'ECoG244'.
        interp (bool, optional): flag to include 2D interpolation of the result. See
            :func:`~aopy.visualization.calc_data_map` for options. Defaults to True.
        grid_size ((2,) tuple, optional): size of the grid to interpolate to. Defaults to (16,16).
        cmap (str, optional): matplotlib colormap to use in image. Defaults to 'bwr'.
        theta (float): rotation (in degrees) to apply to positions. rotations are applied clockwise, e.g.,
            theta = 90 rotates the map clockwise by 90 degrees, -90 rotates the map anti-clockwise by
            90 degrees. Default 0.
        ax (pyplot.Axes, optional): axis on which to plot. Defaults to None (current axis).
        kwargs (dict): 'interp_method', 'threshold_dist' and 'extent' are forwarded to calc_data_map
            (only when interp is True); 'alpha_map', 'nan_color' and 'clim' are forwarded to
            plot_spatial_map. Other keys are ignored.

    Returns:
        pyplot.Image: image returned by pyplot.imshow. Use to add colorbar, etc.

    Updated in v0.9.1 - removed bad_elec argument, added elec_data argument
    '''
    if ax is None:
        ax = plt.gca()

    # Load the signal path files
    elec_pos, acq_ch, elecs = aodata.load_chmap(drive_type=drive_type, theta=theta)
    data = np.asarray(data) # allow list input for the fancy indexing below
    if not elec_data:
        data = data[acq_ch-1] # acq_ch holds 1-indexed channel numbers

    # Interpolate or directly compute the map
    if interp:
        interp_kwargs = {k: v for k, v in kwargs.items() if k in ['interp_method', 'threshold_dist', 'extent']}
        data_map, xy = calc_data_map(data, elec_pos[:,0], elec_pos[:,1], grid_size, **interp_kwargs)
    else:
        data_map = get_data_map(data, elec_pos[:,0], elec_pos[:,1])
        xy = [elec_pos[:,0], elec_pos[:,1]]

    # Plot
    plot_kwargs = {k: v for k, v in kwargs.items() if k in ['alpha_map', 'nan_color', 'clim']}
    im = plot_spatial_map(data_map, xy[0], xy[1], cmap=cmap, ax=ax, **plot_kwargs)
    return im
def plot_ECoG244_data_map(data, elec_data=False, interp=True, cmap='bwr', theta=0, ax=None, **kwargs):
    '''
    Plot a spatial map of data from an ECoG244 electrode array from the Viventi lab.
    This is plot_spatial_drive_map() with drive_type='ECoG244' and a (16,16) grid.

    Args:
        data ((256,) array): values from the ECoG array to plot in 2D
        elec_data (bool, optional): if True, treat data as electrode data (i.e. nch == nelec), otherwise
            treat it as acquisition data (nch >= nelec). Defaults to False.
        interp (bool, optional): flag to include 2D interpolation of the result. See
            :func:`~aopy.visualization.calc_data_map` for options. Defaults to True.
        cmap (str, optional): matplotlib colormap to use in image. Defaults to 'bwr'.
        theta (float): rotation (in degrees) to apply to positions. rotations are applied clockwise, e.g.,
            theta = 90 rotates the map clockwise by 90 degrees, -90 rotates the map anti-clockwise by
            90 degrees. Default 0.
        ax (pyplot.Axes, optional): axis on which to plot. Defaults to None.
        kwargs (dict): dictionary of additional keyword argument pairs to send to calc_data_map and
            plot_spatial_map.

    Returns:
        pyplot.Image: image returned by pyplot.imshow. Use to add colorbar, etc.

    Updated in v0.9.1 - removed bad_elec argument, added elec_data argument

    Examples:

        .. code-block:: python

            elec_pos, acq_ch, elecs = aopy.data.load_chmap('ECoG244')
            data = np.linspace(-1, 1, 256)
            missing = [0, 5, 25]
            missing_ch = acq_ch[np.isin(elecs, missing)]-1
            data[missing_ch] = np.nan

            plt.figure()
            plot_ECoG244_data_map(data, interp=False, cmap='bwr', ax=None)
            # Here the missing electrodes (in addition to the ones
            # undefined by the channel mapping) should be visible in the map.

            plt.figure()
            plot_ECoG244_data_map(data, interp=False, cmap='bwr', ax=None, nan_color=None)
            # Now we make the missing electrodes transparent

            plt.figure()
            plot_ECoG244_data_map(data, interp=True, cmap='bwr', ax=None)
            # Missing electrodes are filled in by interpolation (nearest-neighbor by default).
    '''
    ecog244_grid = (16,16)
    return plot_spatial_drive_map(
        data,
        elec_data=elec_data,
        interp=interp,
        grid_size=ecog244_grid,
        drive_type='ECoG244',
        cmap=cmap,
        theta=theta,
        ax=ax,
        **kwargs,
    )
def plot_spatial_drive_maps(maps, nrows_ncols, axsize, clim=None, theta=0, axes_pad=0.05, label_mode="1",
                            cbar_mode=None, **kwargs):
    '''
    Plot multiple spatial maps on the same figure. Uses mpl_toolkits.axes_grid1.ImageGrid to create a
    grid of axes.

    Args:
        maps (list): list of (nch,) list of values recorded from a spatial drive (e.g. electrode array) to plot.
            Maps containing only NaN are skipped (their axes stay empty and their image handle is None).
        nrows_ncols ((2,) tuple): number of rows and columns of subplots
        axsize ((2,) tuple): (width, height) size of each subplot in inches
        clim ((2,) tuple, optional): (min, max) to set the color axis limits. Default None, show the whole
            range, each image will be scaled independently.
        theta (float or list of, optional): rotation (in degrees) to apply to positions. rotations are applied
            clockwise, e.g., theta = 90 rotates the map clockwise by 90 degrees, -90 rotates the map
            anti-clockwise by 90 degrees. Default 0. If a list is given, it must be the same length as maps
            and each map will be rotated by the corresponding theta value.
        axes_pad (float, optional): padding between axes. Default 0.05
        label_mode (str, optional): label mode for ImageGrid {"L", "1", "all", None}. Default "1".
            None uses ImageGrid's "1" mode and additionally removes all ticks and axis labels.
        cbar_mode (str, optional): colorbar mode for ImageGrid {"each", "single", None}. Default None.
            With "single", the colorbar is attached to the first non-empty map.
        **kwargs: additional keyword arguments to pass to :func:`~aopy.visualization.plot_spatial_drive_map`

    Returns:
        tuple: tuple containing:
            - **fig (pyplot.Figure):** the created figure
            - **axes (ImageGrid):** the created axes returned by ImageGrid
            - **ims (list):** list of image handles (None for skipped maps)
            - **cbars (list):** list of colorbar handles

    Examples:
        Create some test maps (ECoG244, ECoG244 flipped, random, random flipped) and plot them in different
        configurations. First, plot them in a 1x4 grid.

        .. code-block:: python

            im1 = np.arange(256).astype(float)
            im2 = np.flip(im1)
            im3 = im1.copy()
            np.random.shuffle(im3)
            im4 = np.flip(im3)
            maps = [im1, im2, im3, im4]
            plot_spatial_drive_maps(maps, (1,4), (2,2), cmap='viridis', clim=(0,255), label_mode="L")
            plt.tight_layout()

        .. image:: _images/spatial_drive_maps_1_4.png

        Now plot them in a 2x2 grid with a single colorbar.

        .. code-block:: python

            plot_spatial_drive_maps(maps, (2,2), (2,2), cmap='viridis', clim=(0,255), cbar_mode='single')
            plt.tight_layout()

        .. image:: _images/spatial_drive_maps_2_2_single_cbar.png

        Last plot them in a 2x2 grid with a colorbar for each map. We need to change the horizontal
        spacing to make the colorbars fit. We can also make adjustments after plotting using the
        returned image handles.

        .. code-block:: python

            fig, axes, ims, cbars = plot_spatial_drive_maps(maps, (2,2), (2,2), cmap='viridis', clim=(0,255),
                                                            label_mode=None, cbar_mode='each', axes_pad=(0.4,0.05))
            ims[0].set_clim(127,255)
            plt.tight_layout()

        .. image:: _images/spatial_drive_maps_2_2.png
    '''
    # Create a grid of axes
    fig = plt.figure(figsize=(axsize[0] * nrows_ncols[1], axsize[1] * nrows_ncols[0]))
    axes = ImageGrid(fig, 111, nrows_ncols=nrows_ncols, axes_pad=axes_pad,
                     label_mode="1" if label_mode is None else label_mode,
                     cbar_mode=cbar_mode, cbar_pad=0.05)

    # Plot each map using the existing function
    ims = []
    cbars = []
    for n, spatial_map in enumerate(maps):
        if np.count_nonzero(~np.isnan(spatial_map)) == 0:
            ims.append(None)
            continue
        ax = axes.axes_all[n]
        map_theta = theta[n] if isinstance(theta, (list, np.ndarray)) else theta
        im = plot_spatial_drive_map(spatial_map, ax=ax, theta=map_theta, **kwargs)
        if label_mode is None:
            ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[], xlabel='', ylabel='')
        if cbar_mode == 'each':
            cbars.append(ax.cax.colorbar(im))
        if clim is not None:
            im.set_clim(clim)
        ims.append(im)

    # ims[0] may be None when the first map is empty; use the first real image instead
    if cbar_mode == 'single':
        first_im = next((im for im in ims if im is not None), None)
        if first_im is not None:
            cbars.append(axes.cbar_axes[0].colorbar(first_im))
    return fig, axes, ims, cbars
def annotate_spatial_map(elec_pos, text, color, annotation_style='text', fontsize=6, marker='o', markersize=0.25,
                         ax=None, **kwargs):
    '''
    Add either a text label or a marker at a 2d position.

    Args:
        elec_pos ((x,y) tuple): position where the annotation should be placed on the 2d plot
        text (str): annotation text (only used when annotation_style is 'text')
        color (plt.Color): color of the text or marker
        annotation_style (str, optional): style of annotation to use for stimulation site ['text', 'marker'].
            Default 'text'.
        fontsize (int, optional): font size of text annotations (unused for markers). Defaults to 6.
        marker (str, optional): marker style for annotations if annotation_style is 'marker'. Options are the
            same as pyplot.markers.MarkerStyle; e.g. 'o', 's', etc. Default 'o'.
        markersize (float, optional): size of the marker in data units (measured along the x-axis) if
            annotation_style is 'marker'. Defaults to 0.25.
        ax (pyplot.Axes, optional): axis on which to plot. Defaults to None.
        kwargs (dict): additional keyword arguments passed to ax.annotate() for text, or ax.plot() for markers

    Returns:
        plt.Annotation or plt.Line2D: the created annotation (text) or line (marker) object

    Raises:
        ValueError: if annotation_style is neither 'text' nor 'marker'
    '''
    if ax is None:
        ax = plt.gca()

    if annotation_style == 'text':
        return ax.annotate(text, elec_pos, color=color, fontsize=fontsize, ha='center', va='center', **kwargs)
    if annotation_style != 'marker':
        raise ValueError("annotation_style must be either 'text' or 'marker'.")

    # Convert the requested size from data units to points: pixels per data unit (x-axis)
    # from transData, then pixels -> points using the figure dpi.
    points_per_unit = ax.transData.get_matrix()[0,0]
    points_per_unit *= 72.0 / ax.get_figure().dpi
    style = MarkerStyle(marker)
    path_width = style.get_path().transformed(style.get_transform()).get_extents().width
    marker_points = markersize*points_per_unit/path_width/2
    handles = ax.plot(elec_pos[0], elec_pos[1], marker=marker, color=color, markersize=marker_points, **kwargs)
    return handles[0]
def annotate_spatial_map_channels(acq_idx=None, acq_ch=None, drive_type='ECoG244', theta=0, color='k',
                                  annotation_style='text', fontsize=6, marker='o', markersize=0.25, ax=None,
                                  **kwargs):
    '''
    Given acq_idx (indices) or acq_ch (channel numbers), prints either indices or channel numbers on top
    of a spatial map.

    Args:
        acq_idx ((nacq,) array or list, optional): If provided, specifies the acquisition indices to be
            annotated (and the labels are printed as indices). If neither acq_idx nor acq_ch is provided,
            all channel numbers will be annotated by default.
        acq_ch ((nacq,) array or list, optional): If provided, specifies the acquisition channel numbers to
            be annotated. If neither acq_idx nor acq_ch is provided, all channel numbers will be annotated
            by default.
        drive_type (str, optional): Drive type of the channels to plot. See :func:`aopy.data.base.load_chmap`.
        theta (float, optional): rotation (in degrees) passed to load_chmap for the channel positions. Default 0.
        color (str, tuple, or list, optional): a single color (name, hex string, or RGB/RGBA tuple) applied
            to every channel, or one color per channel. Default 'k'.
        annotation_style (str, optional): style of annotation to use for stimulation site ['text', 'marker'].
            Default 'text'.
        fontsize (int, optional): the fontsize of text annotations. Defaults to 6.
        marker (str, optional): marker style for annotations if annotation_style is 'marker'. Options are the
            same as pyplot.markers.MarkerStyle; e.g. 'o', 's', etc. Default 'o'.
        markersize (float, optional): size of the marker in data units if annotation_style is 'marker'.
            Defaults to 0.25.
        ax (pyplot.Axes, optional): axis on which to plot. Defaults to None.
        kwargs (dict): additional keyword arguments to pass to annotate_spatial_map()

    Raises:
        ValueError: if both acq_idx and acq_ch are given

    Example:

        .. code-block:: python

            aopy.visualization.plot_ECoG244_data_map(np.zeros(256,), cmap='Greys')
            aopy.visualization.annotate_spatial_map_channels(drive_type='ECoG244', color='k')
            aopy.visualization.annotate_spatial_map_channels(drive_type='Opto32', color='b',
                                                             annotation_style='marker')
            plt.axis('off')

        .. image:: _images/ecog244_opto32.png

    Note:
        The acq_ch returned from `func::aopy.data.load_chmap` are generally 1-indexed lists of acquisition
        channels connected to electrodes. In python, however, the acquisition indices start at 0, so we give
        the option to select channels based on either an index (acq_idx) or a channel number (acq_ch).
    '''
    if ax is None:
        ax = plt.gca()
    if acq_idx is not None and acq_ch is not None:
        raise ValueError("Please specify only one of acq_idx or acq_ch.")
    if acq_idx is not None:
        acq_ch = np.array(acq_idx)+1 # Change from index to ch numbers

    # Get channel map (overwrite acq_ch if it was supplied to get the correct shape acq_ch)
    elec_pos, acq_ch, elecs = aodata.load_chmap(drive_type, acq_ch, theta)

    # Expand to one color per channel. A single RGB(A) tuple must be replicated as a whole;
    # np.repeat would flatten it into scalar "colors".
    n_elec = len(elec_pos)
    if isinstance(color, str) or colors.is_color_like(color):
        channel_colors = [color] * n_elec
    elif len(color) < n_elec:
        # NOTE(review): legacy behavior kept for compatibility -- repeats each entry n_elec times, so a
        # short list effectively colors most channels with its first entry.
        channel_colors = np.repeat(np.array(color), n_elec)
    else:
        channel_colors = color

    # Annotate each channel
    for pos, ch, ch_color in zip(elec_pos, acq_ch, channel_colors):
        label = ch - 1 if acq_idx is not None else ch # show indices when indices were requested
        annotate_spatial_map(pos, label, ch_color, annotation_style=annotation_style, fontsize=fontsize,
                             marker=marker, markersize=markersize, ax=ax, **kwargs)
def plot_image_by_time(time, image_values, ylabel='trial', cmap='bwr', ax=None):
    '''
    Draw a time-by-trial (or time-by-channel) matrix as an image, with time on the x axis.

    Example:
        ::

            time = np.array([-2, -1, 0, 1, 2, 3])
            data = np.array([[0, 0, 1, 1, 0, 0],
                             [0, 0, 0, 1, 1, 0]]).T
            plot_image_by_time(time, data)
            filename = 'image_by_time.png'

        .. image:: _images/image_by_time.png

    Args:
        time (nt): time vector spanning the x axis
        image_values (nt, [nch or ntr]): time-by-trial or time-by-channel data
        ylabel (str, optional): name of the second axis of image_values. Defaults to 'trial'.
        cmap (str, optional): colormap used to render the image. Defaults to 'bwr'.
        ax (pyplot.Axes, optional): Axes object to draw on. Defaults to the current axis.

    Returns:
        pyplot.AxesImage: the image object created by imshow
    '''
    values = np.array(image_values)
    # x spans the time range, y spans one row per trial/channel
    bounds = [np.min(time), np.max(time), 0, values.shape[1]]

    target_ax = plt.gca() if ax is None else ax
    image = target_ax.imshow(
        values.T,
        cmap=cmap,
        origin='lower',
        extent=bounds,
        aspect='auto',
        resample=False,
        interpolation='none',
        filternorm=False,
    )
    target_ax.set_xlabel('time (s)')
    target_ax.set_ylabel(ylabel)
    return image
def plot_raster(data, cue_bin=None, ax=None):
    '''
    Draw a raster of binary events, optionally marking an event time with a vertical red line.

    .. image:: _images/raster_plot_example.png

    Args:
        data (ntime, ncolumns): 2D array of events, typically spiking across channels or trials.
            Values must be 0 or 1 (not spike counts).
        cue_bin (float): time bin of an event to mark (e.g. 'Go Cue' or 'Leave center').
            Leave as None to plot only the data.
        ax (plt.Axis): axis on which to draw. Defaults to the current axis.

    Returns:
        None: the raster is drawn on the axis
    '''
    target_ax = plt.gca() if ax is None else ax
    target_ax.eventplot(data.T, color='black')
    if cue_bin is None:
        return
    target_ax.axvline(x=cue_bin, linewidth=2.5, color='r')
def plot_angles(angles, magnitudes=None, ax=None, **kwargs):
    '''
    Polar plot of angles and optional magnitudes. Useful for plotting ITPC or other phase data.

    Args:
        angles (nt): array of angles in radians
        magnitudes (nt, optional): array of magnitudes to plot as line lengths. Defaults to all ones.
        ax (plt.Axis, optional): polar axis to draw on. If None, the current axis is used when it is
            polar; otherwise a new figure with a polar axis is created.
        **kwargs: additional keyword arguments to pass to plt.plot

    Examples:
        .. code-block:: python

            angles = np.linspace(np.pi/8, 2*np.pi + np.pi/8, 8, endpoint=False)
            plot_angles(angles)

        .. image:: _images/angles_simple.png

        .. code-block:: python

            angles = np.linspace(np.pi/8, 2*np.pi + np.pi/8, 8, endpoint=False)
            magnitudes = np.arange(len(angles)) + 1
            fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
            plot_angles(angles, magnitudes, ax)

        .. image:: _images/angles_magnitudes.png
    '''
    if ax is None:
        # Only consult plt.gca() when a figure already exists: calling it with no open figure
        # creates a throwaway cartesian figure before the polar one below is made.
        if plt.get_fignums() and plt.gca().name == 'polar':
            ax = plt.gca()
        else:
            _, ax = plt.subplots(subplot_kw={'projection': 'polar'})
    if magnitudes is None:
        magnitudes = np.ones(len(angles))

    # Draw one line from the origin to (angle, magnitude) per entry
    for a, m in zip(angles, magnitudes):
        ax.plot([0, a], [0, m], **kwargs)
def set_bounds(bounds, ax=None):
    '''
    Sets the x, y, and (for 3D axes) z limits from the given bounds, each padded by a factor of 1.1.

    Args:
        bounds (tuple): 4- or 6-element tuple describing (-x, x, -y, y[, -z, z]) cursor bounds.
            The z limits are applied only when 6 elements are given and ax supports z limits.
        ax (plt.Axis, optional): axis whose limits are set. Defaults to the current axis.

    Raises:
        IndexError: if bounds has fewer than 4 elements.
    '''
    if ax is None:
        ax = plt.gca()
    limits = {
        'xlim': (1.1 * bounds[0], 1.1 * bounds[1]),
        'ylim': (1.1 * bounds[2], 1.1 * bounds[3]),
    }
    # Only 3D axes have z limits; check explicitly instead of catching every exception
    if len(bounds) >= 6 and hasattr(ax, 'set_zlim'):
        limits['zlim'] = (1.1 * bounds[4], 1.1 * bounds[5])
    ax.set(**limits)
def color_targets(target_locations, target_idx, colors, target_radius, bounds=None, ax=None, **kwargs):
    '''
    Color targets according to their index. Useful for visualizing unique targets when trajectories
    aren't obviously aligned to specific targets.

    Args:
        target_locations ((ntargets, 2) or (ntargets, 3) array): array of target (x, y[, z]) locations
        target_idx ((ntargets,) array): integer index for each target. Target i is drawn with
            colors[target_idx[i]], so every index must be a valid position in colors.
        colors (list): list of colors, indexed by the values in target_idx
        target_radius (float): radius of the targets in cm
        bounds (tuple, optional): 4- or 6-element tuple describing (-x, x, -y, y[, -z, z]) cursor bounds
        ax (plt.Axis, optional): axis to plot the targets on (2D or 3D)
        **kwargs: additional keyword arguments to pass to plot_circles()

    Examples:
        Create and plot eight targets for a center-out task.

        .. code-block:: python

            angles = np.linspace(0, 2*np.pi, 8, endpoint=False)
            radius = 6.5
            target_locations = np.column_stack((radius * np.cos(angles), radius * np.sin(angles)))
            target_locations = np.vstack(([0, 0], target_locations))

        Specify the colors per target index in case they are out of order.

        .. code-block:: python

            target_idx = [0] + np.arange(1, 9).tolist() # Center is index 0, peripheral are index 1 through 8
            colors = ['black'] + sns.color_palette("husl", 8)
            target_radius = 0.5
            bounds = (-8, 8, -8, 8)

        Use :func:`~aopy.visualization.color_targets` to plot the targets

        .. code-block:: python

            fig, ax = plt.subplots(figsize=(8, 8))
            color_targets(target_locations, target_idx, colors, target_radius, bounds, ax)
            ax.set_aspect('equal')
            filename = 'color_targets.png'

        .. image:: _images/color_targets.png
    '''
    assert len(target_locations) == len(target_idx), "Target locations must be the same length as target indices"
    if len(target_idx) == 0:
        return  # nothing to draw

    target_locations = np.array(np.array(target_locations).tolist())
    target_idx = np.array(np.array(target_idx).tolist())

    # Colors are looked up by index value (colors[idx]), so the largest index must fit in the list.
    # Counting unique indices is not enough when indices are sparse, e.g. [0, 5].
    assert np.max(target_idx) < len(colors), "Not enough colors for target indices"

    # Draw each unique (index, location) pair once
    loc_idx = np.concatenate((np.expand_dims(target_idx, 1), target_locations), axis=1)
    loc_idx = np.unique(loc_idx, axis=0)
    for row in loc_idx:
        idx = row[0].astype(int)
        loc = row[1:]
        plot_circles([loc], target_radius, colors[idx], bounds=bounds, ax=ax, **kwargs)
def plot_targets(target_positions, target_radius, bounds=None, alpha=0.5, origin=(0, 0, 0), ax=None, unique_only=True):
    '''
    Add targets to an axis. If any targets are at the origin, they will appear
    in a different color (magenta). Works for 2D and 3D axes

    Example:
        Plot four peripheral and one central target.
        ::

            target_position = np.array([
                [0, 0, 0],
                [1, 1, 0],
                [-1, 1, 0],
                [1, -1, 0],
                [-1, -1, 0]
            ])
            target_radius = 0.1
            plot_targets(target_position, target_radius, (-2, 2, -2, 2))

        .. image:: _images/targets.png

    Args:
        target_positions (ntarg, 3): array of target (x, y, z) locations
        target_radius (float): radius of each target
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) cursor bounds
        alpha (float or (ntarg,) list, optional): transparency, either one value for all targets or one
            value per plotted target (after de-duplication when unique_only is True). Default 0.5.
        origin (tuple, optional): (x, y, z) position of the origin
        ax (plt.Axis, optional): axis to plot the targets on
        unique_only (bool, optional): If True, function will only plot targets with unique positions (default: True)
    '''
    if unique_only:
        target_positions = np.unique(target_positions, axis=0)

    if np.ndim(alpha) == 0:
        # Accept any real scalar (e.g. alpha=1), not only Python floats
        alpha = float(alpha) * np.ones(len(target_positions))
    else:
        assert len(alpha) == len(target_positions), "list of alpha values must be equal in length to the list of targets."
        # plot_circles treats a float as a single alpha, so pass floats through
        alpha = np.asarray(alpha, dtype=float)

    if ax is None:
        ax = plt.gca()

    for target_pos, target_alpha in zip(target_positions, alpha):
        # Pad the vector to make sure it is length 3
        pos = np.zeros((3,))
        pos[:len(target_pos)] = target_pos

        # Color according to its position
        target_color = 'm' if (pos == origin).all() else 'b'
        plot_circles([pos], target_radius, target_color, bounds, target_alpha, ax, unique_only=False)
def plot_circles(circle_positions, circle_radius, circle_color='b', bounds=None, alpha=0.5, ax=None, unique_only=True):
    '''
    Add circles to an axis. Spheres are drawn on 3D axes, flat circles on 2D axes.

    Args:
        circle_positions (ntarg, 2 or 3): array of circle (x, y[, z]) locations
        circle_radius (float): radius of each circle
        circle_color (str): color to draw circle - default is blue
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) cursor bounds
        alpha (float or (ntarg,) list, optional): transparency, either one value for all circles or one
            value per plotted circle (after de-duplication when unique_only is True). Default 0.5.
        ax (plt.Axis, optional): axis to plot the circles on
        unique_only (bool, optional): If True, function will only plot circles with unique positions (default: True)
    '''
    if unique_only:
        circle_positions = np.unique(circle_positions, axis=0)
    if np.ndim(alpha) == 0:
        # Accept any real scalar (e.g. alpha=1), not only Python floats
        alpha = float(alpha) * np.ones(len(circle_positions))
    else:
        assert len(alpha) == len(circle_positions), "list of alpha values must be equal in length to the list of cricles."
    if ax is None:
        ax = plt.gca()
    if len(circle_positions) == 0:
        return

    # Detect 3D axes explicitly rather than catching every exception from the 3D drawing code,
    # which would hide real errors and retry in 2D on a 3D axis.
    is_3d = hasattr(ax, 'set_zlabel')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if is_3d:
        ax.set_zlabel('z')
        # Unit sphere mesh, shared by every circle
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)
        unit_x = np.outer(np.cos(u), np.sin(v))
        unit_y = np.outer(np.sin(u), np.sin(v))
        unit_z = np.outer(np.ones(np.size(u)), np.cos(v))

    for i, circle_pos in enumerate(circle_positions):
        # Pad the vector to make sure it is length 3
        pos = np.zeros((3,))
        pos[:len(circle_pos)] = circle_pos
        if is_3d:
            ax.plot_surface(pos[0] + circle_radius * unit_x,
                            pos[1] + circle_radius * unit_y,
                            pos[2] + circle_radius * unit_z,
                            alpha=alpha[i], color=circle_color)
        else:
            target = plt.Circle((pos[0], pos[1]), radius=circle_radius, alpha=alpha[i], color=circle_color)
            ax.add_artist(target)

    if is_3d:
        ax.set_box_aspect((1, 1, 1))
    else:
        ax.set_aspect('equal', adjustable='box')
    if bounds is not None:
        set_bounds(bounds, ax)
def plot_3D_as_2D(trajectories, ax):
    """
    Flattens 3D trajectory data to 2D for plotting and sets appropriate axis labels.

    This helper converts a list of 3D trajectories into 2D by removing one axis. If exactly
    one axis is all-zero it is removed. If two or more axes are all-zero, the second of them is
    removed. If no axis is all-zero, the z-axis is removed and a warning is issued. The axis
    labels name the two columns that are kept, in order (first kept column on x, second on y).

    Args:
        trajectories (list of numpy.ndarray): A list of trajectory arrays, each of shape (T, D),
            where T is the number of time steps and D is the dimensionality (typically 3).
        ax (matplotlib.axes.Axes): The 2D axis on which the trajectories will be plotted.

    Returns:
        numpy.ndarray or list: An object array of 2D trajectories, each of shape (T, 2), with one
        axis removed. Data with 2 or fewer columns is returned unchanged.

    Example:
        traj1 = np.array([[0, 0, 0], [1, 2, 0], [2, 4, 0]])
        traj2 = np.array([[0, 0, 0], [1, 1, 0], [2, 2, 0]])
        fig, ax = plt.subplots()
        flat_trajs = plot_3D_as_2D([traj1, traj2], ax)
        for traj in flat_trajs:
            ax.plot(traj[:, 0], traj[:, 1])
        plt.show()
    """
    axis_names = ['x', 'y', 'z']
    stacked = np.vstack(trajectories)
    if stacked.shape[1] <= 2:
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        return trajectories

    zero_col_idx = np.where(np.all(stacked == 0, axis=0))[0]
    if len(zero_col_idx) == 0:
        warnings.warn("Axis is unclear for 3D data (no zero columns). Plots assume data of interest is in xy-plane.", stacklevel=2)
        col_to_remove = 2  # assume z-axis data is unwanted
    elif len(zero_col_idx) > 1:
        col_to_remove = zero_col_idx[1]  # 1D data; remove 2nd all-zero axis
    else:
        col_to_remove = zero_col_idx[0]  # 2D data; remove only all-zero axis

    kept = [i for i in range(3) if i != col_to_remove]
    flattened = np.array([traj[:, kept] for traj in trajectories], dtype=object)

    # Label from the columns actually kept. Deriving labels from zero_col_idx swapped them when
    # x was removed (kept y, z were labelled z, y).
    ax.set_xlabel(axis_names[kept[0]])
    ax.set_ylabel(axis_names[kept[1]])
    return flattened
def plot_trajectories(trajectories, bounds=None, ax=None, **kwargs):
    '''
    Draws the given trajectories, one at a time in different colors. Works for 2D and 3D axes.
    If 2D axes are given with 3D data, dimensions of interest are inferred from zero-columns if present.
    Plotting 3D data with no zero-columns on a 2D axis will show the data in the xy-plane (first two dimensions).

    Example:
        Two random trajectories.
        ::

            trajectories =[
                np.array([
                    [0, 0, 0],
                    [1, 1, 0],
                    [2, 2, 0],
                    [3, 3, 0],
                    [4, 2, 0]
                ]),
                np.array([
                    [-1, 1, 0],
                    [-2, 2, 0],
                    [-3, 3, 0],
                    [-3, 4, 0]
                ])
            ]
            bounds = (-5., 5., -5., 5., 0., 0.)
            plot_trajectories(trajectories, bounds)

        .. image:: _images/trajectories.png

        ::

            trajectories =[
                np.array([
                    [0, 0, 0],
                    [1, 0, 1],
                    [2, 0, 2],
                    [3, 0, 3],
                    [4, 0, 2]
                ]),
                np.array([
                    [-1, 0, 1],
                    [-2, 0, 2],
                    [-3, 0, 3],
                    [-3, 0, 4]
                ])
            ]
            bounds = (-5., 5., -5., 5., 0., 0.)
            plot_trajectories(trajectories, bounds)

        .. image:: _images/trajectories_flat.png

        ::

            trajectories =[
                np.array([
                    [0, 0, 0],
                    [0, 1, 0],
                    [0, 2, 0],
                    [0, 3, 0],
                    [0, 2, 0]
                ]),
                np.array([
                    [0, 1, 0],
                    [0, 2, 0],
                    [0, 3, 0],
                    [0, 4, 0]
                ])
            ]
            bounds = (-5., 5., -5., 5., 0., 0.)
            plot_trajectories(trajectories, bounds)

        .. image:: _images/trajectories_1D.png

    Args:
        trajectories (list): list of (n, 2) or (n, 3) trajectories where n can vary across each trajectory
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) cursor bounds
        ax (plt.Axis, optional): axis to plot the trajectories on
        kwargs (dict): keyword arguments to pass to the plt.plot function
    '''
    if ax is None:
        ax = plt.gca()

    # Dispatch on the axis type explicitly; a bare except here used to hide genuine plotting errors
    if hasattr(ax, 'set_zlabel'):
        ax.set_zlabel('z')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        for traj in trajectories:
            ax.plot(*traj.T, **kwargs)
        ax.set_box_aspect((1, 1, 1))
    else:
        flattened = plot_3D_as_2D(trajectories, ax)
        for traj in flattened:
            ax.plot(traj[:, 0], traj[:, 1], **kwargs)
        ax.set_aspect('equal', adjustable='box')

    if bounds is not None:
        set_bounds(bounds, ax)
def color_trajectories(trajectories, labels, colors, ax=None, **kwargs):
    '''
    Draws trajectories with each one (or each labelled piece of one) colored by its label.
    Works for 2D and 3D axes.

    Example:
        .. code-block:: python

            trajectories =[
                np.array([
                    [0, 0, 0],
                    [1, 1, 0],
                    [2, 2, 0],
                    [3, 3, 0],
                    [4, 2, 0]
                ]),
                np.array([
                    [-1, 1, 0],
                    [-2, 2, 0],
                    [-3, 3, 0],
                    [-3, 4, 0]
                ])
            ]
            labels = [0, 0, 1, 0]
            colors = ['r', 'b']
            color_trajectories(trajectories, labels, colors)

        .. image:: _images/color_trajectories_simple.png

            labels_list = [[0, 0, 1, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0]]
            fig = plt.figure()
            color_trajectories(trajectories, labels_list, colors)

        .. image:: _images/color_trajectories_segmented.png

    Args:
        trajectories (ntrials): list of (n, 2) or (n, 3) trajectories where n can vary across each trajectory
        labels (ntrials): either one integer label per trajectory, or a list of lists/arrays where
            each sublist gives the label of every timepoint in the corresponding trajectory (the
            trajectory is then split into constant-label pieces)
        colors (ncolors): list of colors, indexed by label
        ax (plt.Axis, optional): axis to plot the trajectories on
        **kwargs (dict): other arguments for plot_trajectories(), e.g. bounds
    '''
    per_timepoint_labels = isinstance(labels[0], (list, np.ndarray))
    if per_timepoint_labels:
        # Break every trajectory into runs of constant label so each run can get its own color
        split_trajectories = []
        split_labels = []
        for traj, traj_labels in zip(trajectories, labels):
            assert len(traj) == len(traj_labels), "Input labels must be the same length as input trajectories"
            pieces, piece_labels = utils.segment_array(traj, traj_labels, duplicate_endpoints=True)
            split_trajectories += pieces
            split_labels += piece_labels
        trajectories, labels = split_trajectories, split_labels

    # Integer labels index into the color list; the color cycle then walks through them in order
    label_ints = np.array(labels).astype(int)
    color_cycle = plt.cycler(color=[colors[k] for k in label_ints])

    target_ax = plt.gca() if ax is None else ax
    target_ax.set_prop_cycle(color_cycle)

    plot_trajectories(trajectories, ax=target_ax, **kwargs)
def _gradient_segments(traj, n_colors, color_list, n_dims):
    '''Split one trajectory into consecutive segments colored evenly from color_list[0] to color_list[-1].'''
    n_pt = len(traj)
    if n_pt < n_colors:
        warnings.warn("Not enough datapoints to divide into n_colors!")

    # Assign labels to the trajectory according to color
    labels = np.zeros((n_pt,), dtype='int')
    size = (n_pt // n_colors) * n_colors  # largest size we can evenly split into n_colors
    labels[:size] = np.repeat(range(n_colors), n_pt // n_colors)
    labels[size:] = n_colors - 1  # leftovers also get the last color

    # Split the labeled trajectory into segments with unique colors
    segments, labels = utils.segment_array(traj, labels, duplicate_endpoints=True)
    labels = np.array(labels).astype(int)
    seg_colors = [color_list[i] for i in labels]
    segments = [s[:, :n_dims] for s in segments]
    return segments, seg_colors


def gradient_trajectories(trajectories, n_colors=100, color_palette='viridis', bounds=None, ax=None, **kwargs):
    '''
    Draw trajectories with a gradient of color from start to end of each trajectory. Works in 2D and 3D.
    If 2D axes are given with 3D data, dimensions of interest are inferred from zero-columns if present.
    Plotting 3D data with no zero-columns on a 2D axis will show the data in the xy-plane (first two
    dimensions). 2D data drawn on a 3D axis is placed at z=0.

    Note: this function applies the gradient evenly across the timepoints of the trajectory. It might be
    useful to use the sampling rate of the data instead of n_colors, so that the time axis is consistent
    across sampling rates.

    Args:
        trajectories (ntrials): list of 2D or 3D trajectories, in x, y[, z] coordinates
        n_colors (int, optional): number of colors in the gradient. Default 100.
        color_palette (str, optional): colormap to use for the gradient. Default 'viridis'.
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) axes bounds.
            If None, a 5% margin is applied instead.
        ax (plt.Axis, optional): axis to plot the trajectories on
        kwargs (dict): keyword arguments to pass to the LineCollection function (similar to plt.plot)

    Example:
        Cursor trajectories in 2D

        .. code-block:: python

            subject = 'beignet'
            te_id = 5974
            date = '2022-07-01'
            preproc_dir = data_dir
            traj, _ = aopy.data.get_kinematic_segments(preproc_dir, subject, te_id, date, [32], [81, 82, 83, 239])
            gradient_trajectories(traj[:3])

        .. image:: _images/gradient_trajectories.png

        Hand trajectories in 3D

        .. code-block:: python

            traj, _ = aopy.data.get_kinematic_segments(preproc_dir, subject, te_id, date, [32], [81, 82, 83, 239], datatype='hand')
            plt.figure()
            ax = plt.axes(projection='3d')
            gradient_trajectories(traj[:3], bounds=[-10,0,60,70,20,40], ax=ax)

        .. image:: _images/gradient_trajectories_3d.png

    Note:
        Automatic bounds aren't set in 3D plots. The best alternative is to first plot in 2D, then use those
        bounds to manually set the first 2 axes bounds for the 3D plot.
    '''
    color_list = sns.color_palette(color_palette, n_colors)
    if ax is None:
        ax = plt.gca()

    # Dispatch on the axis type explicitly instead of relying on a bare except
    if hasattr(ax, 'set_zlabel'):
        ax.set_zlabel('z')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        for traj in trajectories:
            traj = np.asarray(traj)
            if traj.shape[1] < 3:
                # Place lower-dimensional data in the z=0 plane so it can be drawn in 3D
                traj = np.hstack([traj, np.zeros((len(traj), 3 - traj.shape[1]))])
            segments, seg_colors = _gradient_segments(traj, n_colors, color_list, 3)
            ax.add_collection(Line3DCollection(segments, colors=seg_colors, **kwargs))
        ax.set_box_aspect((1, 1, 1))
    else:
        flattened = plot_3D_as_2D(trajectories, ax)
        for traj in flattened:
            segments, seg_colors = _gradient_segments(traj, n_colors, color_list, 2)
            ax.add_collection(LineCollection(segments, colors=seg_colors, **kwargs))
        ax.set_aspect('equal', adjustable='box')

    if bounds is not None:
        set_bounds(bounds, ax)
    else:
        ax.margins(0.05)  # The ax.add_collection() call doesn't automatically set margins
def plot_sessions_by_date(trials, dates, *columns, method='sum', labels=None, ax=None):
    '''
    Plot session data organized by date and aggregated such that if there are multiple rows on a given
    date they are combined into a single value using the given method.

    If the method is 'mean' then the values will be averaged for each day, for example for size of cursor.
    The average is weighted by the number of trials in that session.

    If the method is 'sum' then the values will be added together on each day, for example for number of trials.

    Days with no sessions (or whose values cannot be combined) are plotted as NaN.

    Example:
        Plotting success rate averaged across days.

        .. code-block:: python

            from datetime import date, timedelta
            dates = [date.today() - timedelta(days=1), date.today() - timedelta(days=1), date.today()]
            success = [70, 65, 65]
            trials = [10, 20, 10]

            fig, ax = plt.subplots(1,1)
            plot_sessions_by_date(trials, dates, success, method='mean', labels=['success rate'], ax=ax)
            ax.set_ylabel('success (%)')

        .. image:: _images/sessions_by_date.png

    Args:
        trials (nsessions): number of trials in each session, used as weights when method is 'mean'
        dates (nsessions): date of each session (compared against ``datetime.date`` values)
        *columns (nsessions): dataframe columns or numpy arrays to plot
        method (str, optional): how to combine data within a single date. Can be 'sum' or 'mean'.
        labels (list, optional): string label for each column to go into the legend
        ax (pyplot.Axes, optional): axis on which to plot

    Raises:
        ValueError: if method is not 'sum' or 'mean'
    '''
    # Validate up front: raised inside the per-day try block, this error used to be swallowed
    # and silently produced an all-NaN plot.
    if method not in ('sum', 'mean'):
        raise ValueError("Unknown method for combining data")

    dates = np.array(dates)
    plot_days = pd.date_range(start=np.min(dates), end=np.max(dates)).to_list()
    trials = np.array(trials)
    column_values = [np.array(column) for column in columns]

    n_columns = len(columns)
    n_days = len(plot_days)
    aggregate = np.full((n_columns, n_days), np.nan)
    for idx_day, day in enumerate(plot_days):
        on_day = dates == day.date()
        for idx_column, all_values in enumerate(column_values):
            values = all_values[on_day]
            if len(values) == 0:
                continue  # no sessions on this day; leave NaN
            try:
                if method == 'sum':
                    aggregate[idx_column, idx_day] = np.sum(values)
                else:
                    aggregate[idx_column, idx_day] = np.average(values, weights=trials[on_day])
            except (TypeError, ValueError, ZeroDivisionError):
                # Non-numeric values or trial weights summing to zero: leave this day as NaN
                pass

    if ax is None:
        ax = plt.gca()
    for idx_column, column in enumerate(columns):
        if hasattr(column, 'name'):
            ax.plot(plot_days, aggregate[idx_column, :], '.-', label=column.name)
        else:
            ax.plot(plot_days, aggregate[idx_column, :], '.-')
    ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=(mdates.MO, mdates.TU, mdates.WE, mdates.TH,
                                                                mdates.FR, mdates.SA, mdates.SU)))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    plt.setp(ax.get_xticklabels(), rotation=80)
    if labels:
        ax.legend(labels)
    else:
        ax.legend()
def plot_sessions_by_trial(trials, *columns, dates=None, smoothing_window=None, labels=None, ax=None, **kwargs):
    '''
    Plot session data by absolute number of trials completed. Optionally split up the sessions by date
    and apply smoothing to each day's data.

    Example:
        Plotting success rate over three sessions.

        .. code-block:: python

            success = [70, 65, 60]
            trials = [10, 20, 10]

            fig, ax = plt.subplots(1,1)
            plot_sessions_by_trial(trials, success, labels=['success rate'], ax=ax)
            ax.set_ylabel('success (%)')

        .. image:: _images/sessions_by_trial.png

    Args:
        trials (nsessions): number of trials in each session
        *columns (nsessions): dataframe columns or numpy arrays to plot
        dates (nsessions, optional): dataframe columns or numpy arrays of the date of each session.
            When given, a dashed gray line and a date label mark the first trial of each new date.
        smoothing_window (int, optional): number of trials to smooth. Default no smoothing. When dates
            are given, smoothing is applied separately within each date.
        labels (list, optional): string label for each column to go into the legend
        ax (pyplot.Axes, optional): axis on which to plot
        **kwargs: additional keyword arguments passed to ax.plot
    '''
    if ax is None:
        ax = plt.gca()

    date_chg = []
    trial_dates = None
    if dates is not None:
        # One date per trial; date_chg holds the trial index where each new date begins
        trial_dates = np.repeat(np.array(dates), trials)
        date_chg = np.insert(np.where(np.diff(trial_dates) > timedelta(0))[0] + 1, 0, 0)

    for column in columns:
        # Accumulate individual trials with the values given for each session
        trial_values = np.repeat(np.array(column), trials)

        # Apply smoothing
        if smoothing_window is not None and dates is not None:
            split = np.split(trial_values, date_chg[1:])
            trial_values = np.concatenate([
                analysis.calc_rolling_average(s, window_size=smoothing_window, mode='nan') for s in split
            ])
        elif smoothing_window is not None:
            trial_values = analysis.calc_rolling_average(trial_values, window_size=smoothing_window)

        # Plot with additional kwargs
        if hasattr(column, 'name'):
            ax.plot(trial_values, label=column.name, **kwargs)
        else:
            ax.plot(trial_values, **kwargs)

    # Add date labels once, independent of how many columns were plotted
    for i in date_chg:
        ax.axvline(i, ymin=0, ymax=1, color='gray', alpha=0.5, linestyle='dashed')
        ax.text(i, 1, str(trial_dates[i]), color='gray', rotation=90, ha='left', va='top',
                transform=ax.get_xaxis_transform())

    ax.set_xlabel('trials')
    if labels is not None:
        ax.legend(labels)
    else:
        ax.legend()
def plot_events_time(events, event_timestamps, labels, ax=None, colors=('tab:blue', 'tab:orange', 'tab:green')):
    '''
    This function plots multiple different events on the same plot. The first event (item in the list)
    will be displayed on the bottom of the plot. Each event trace is a step function scaled to 0.9 and
    offset vertically by its position in the list.

    .. image:: _images/events_time.png

    Args:
        events (list (nevents) of 1D arrays (ntime)): List of logical arrays that denote when an event
            (for example, a reward) occurred during an experimental session. Each item in the list
            corresponds to a different event to plot.
        event_timestamps (list (nevents) of 1D arrays ntime): List of 1D arrays of timestamps
            corresponding to the events list.
        labels (list (nevents) of str) : Event names for each list item.
        ax (axes handle): Axes to plot
        colors (sequence of str): Color to use for each list item. Only used when there is at least one
            color per event; otherwise the axis color cycle is used.
    '''
    if ax is None:
        ax = plt.gca()

    n_events = len(events)
    use_colors = n_events <= len(colors)
    for i in range(n_events):
        # Convert to float so plain lists of booleans/ints can be scaled as well as numpy arrays
        trace = 0.9 * np.asarray(events[i], dtype=float) + i + 0.1
        if use_colors:
            ax.step(event_timestamps[i], trace, where='post', c=colors[i])
        else:
            ax.step(event_timestamps[i], trace, where='post')

    ax.set_yticks(np.arange(n_events) + 0.5)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Time (s)')
def plot_waveforms(waveforms, samplerate, plot_mean=True, ax=None):
    '''
    Overlay all input waveforms on one axis, optionally highlighting their mean in red.

    .. image:: _images/waveform_plot_example.png

    Args:
        waveforms (nt, nwfs): waveforms to draw, one per column
        samplerate (float): sampling rate used to build the time axis [Hz]
        plot_mean (bool): when True (default), waveforms are drawn in translucent black and the
            NaN-ignoring mean waveform is drawn in red on top.
        ax (axes handle): axis to draw on. Defaults to the current axis.
    '''
    target_ax = plt.gca() if ax is None else ax
    # Time axis in microseconds
    time_us = 1e6 * np.arange(waveforms.shape[0]) / samplerate

    if not plot_mean:
        target_ax.plot(time_us, waveforms)
    else:
        target_ax.plot(time_us, waveforms, color='black', alpha=0.5)
        target_ax.plot(time_us, np.nanmean(waveforms, axis=1), color='red')

    target_ax.set_xlabel(r'Time ($\mu$s)')
def plot_condition_tuning(per_condition_data, conditions, ylabel='success rate', ax=None, **kwargs):
    '''
    Plot tuning curves for categorical data. Essentially a scatter plot with the mean of each condition
    plotted as a solid line.

    Args:
        per_condition_data (nconditions, ...): data for each condition. Everything after the first
            axis is flattened into trials.
        conditions (nconditions): condition value for each row of per_condition_data
        ylabel (str, optional): label for the y-axis. Default "success rate"
        ax (pyplot.Axes, optional): axis to plot the tuning curves on. Default the current axis.
        **kwargs: additional keyword arguments passed to ax.scatter

    Examples:
        Plot the success rate for 4 different conditions

        .. code-block:: python

            direction = [-np.pi, -np.pi/2, 0, np.pi/2]
            data = np.random.normal(0, 1, (4, 2, 4))
            fig = plt.figure()
            plot_condition_tuning(data, np.degrees(direction))

        .. image:: _images/condition_tuning.png
    '''
    if ax is None:
        ax = plt.gca()
    conditions = np.asarray(conditions)
    per_condition_data = np.array(per_condition_data).reshape(len(per_condition_data), -1)
    ntr = per_condition_data.shape[1]

    # Scatter every trial at its condition. np.repeat matches the row-major flattening of the data
    # (all trials of condition 0, then condition 1, ...); np.tile would pair trials with the wrong condition.
    ax.scatter(np.repeat(conditions, ntr), per_condition_data.ravel(), **kwargs)

    # Add a short horizontal line at each condition's mean
    dist = (np.max(conditions) - np.min(conditions)) / len(conditions)
    for cond_value, cond_data in zip(conditions, per_condition_data):
        mean = np.mean(cond_data)
        ax.plot([cond_value - dist / 2, cond_value + dist / 2], [mean, mean], 'r-')
    ax.set_xlabel('condition')
    ax.set_ylabel(ylabel)
def plot_direction_tuning(per_direction_data, directions, show_var=True, wrap=True, ylabel='success rate', ax=None):
    '''
    Plot tuning curves for directional data. The mean across trials is plotted as a solid line and
    the standard deviation as a shaded region around the mean. Works with both cartesian and polar axes.

    Args:
        per_direction_data (ndir, nch, ntrial): direction responses for each channel. If only one channel,
            can be (ndir, ntrial).
        directions (ndir): unique directions in radians (values larger than 2*pi are assumed to be
            degrees and converted)
        show_var (bool, optional): if True, plots the standard deviation around the mean. Default True.
        wrap (bool, optional): if True, duplicates the first value to wrap the plot around a circle. Default True.
        ylabel (str, optional): label for the y-axis. Default "success rate"
        ax (pyplot.Axes, optional): axis to plot the tuning curves on. Can be cartesian or polar.
            Default the current axis.

    Example:
        Polar plot of tuning curves for 4 targets

        .. code-block:: python

            direction = [-np.pi, -np.pi/2, 0, np.pi/2]
            data = np.random.normal(0, 1, (4, 2, 4))
            plt.figure()
            plot_direction_tuning(data, direction)

        .. image:: _images/direction_tuning.png

        Again but with polar plot

        .. code-block:: python

            fig = plt.figure()
            ax = fig.add_subplot(projection='polar')
            plot_direction_tuning(data, direction)

        .. image:: _images/direction_tuning_polar.png
    '''
    if ax is None:
        ax = plt.gca()
    if np.ndim(per_direction_data) == 1:
        per_direction_data = np.expand_dims(per_direction_data, 1)
    if np.ndim(per_direction_data) == 2:
        per_direction_data = np.expand_dims(per_direction_data, 1)
    if len(directions) != len(per_direction_data):
        directions = np.unique(directions)
    assert len(directions) == len(per_direction_data), "Direction and mean must have the same length"

    # Calculate mean and variance
    mean = np.nanmean(per_direction_data, axis=2)
    if show_var:
        var = np.nanstd(per_direction_data, axis=2)
    else:
        var = np.zeros_like(mean)

    # Sort the data and decide if the data fills a full circle or half circle
    if np.max(np.abs(directions)) > 2 * np.pi:
        directions = np.radians(directions)  # probably in degrees by mistake
    modulo = np.pi
    if np.max(directions) - np.min(directions) >= np.pi:
        modulo = np.pi * 2
    directions = np.array(directions) % modulo
    idx = np.argsort(directions)

    # Wrap around the circle
    if wrap:
        directions = np.hstack((directions[idx], [directions[idx[0]] + modulo]))
        mean = np.vstack((mean[idx], [mean[idx[0]]]))
        var = np.vstack((var[idx], [var[idx[0]]]))
    else:
        directions = directions[idx]
        mean = mean[idx]
        var = var[idx]

    # Plot; decide polar vs cartesian from the axis name rather than a bare except
    is_polar = ax.name == 'polar'
    if not is_polar:
        directions = np.degrees(directions)
    for ch in range(mean.shape[1]):
        ax.plot(directions, mean[:, ch])
        ax.fill_between(directions, mean[:, ch] - var[:, ch], mean[:, ch] + var[:, ch], alpha=0.5)

    if is_polar:
        # Place the label just outside the radial tick labels
        label_position = ax.get_rlabel_position()
        ax.text(np.radians(label_position - 1), ax.get_rmax() * 1.1, ylabel,
                rotation=label_position, ha='left', va='center')
    else:
        ax.set_xlabel('direction (deg)')
        ax.set_ylabel(ylabel)
def plot_tuning_curves(fit_params, mean_fr, targets, n_subplot_cols=5, ax=None):
    '''
    This function plots the tuning curves output from analysis.run_tuningcurve_fit overlaying the actual
    firing rate data. The dashed line is the model fit and the solid line is the actual data.

    .. image:: _images/tuning_curves_plot.png

    Args:
        fit_params (nunits, 3): Model fit coefficients. Output from analysis.run_tuningcurve_fit or
            analysis.curve_fitting_func
        mean_fr (nunits, ntargets): The average firing rate for each unit for each target.
        targets (ntargets): Orientation of each target in a center out task [degrees]. Corresponds to
            order of targets in 'mean_fr'
        n_subplot_cols (int): Number of columns to plot in subplot. This function will automatically
            calculate the number of rows. Defaults to 5
        ax (axes handle or array of axes): Axes to plot, filled in row-major order (one per unit).
            Unused axes are removed. If None, a new figure is created.
    '''
    nunits = mean_fr.shape[0]
    n_subplot_rows = ((nunits - 1) // n_subplot_cols) + 1
    axinput = ax is not None
    if ax is None:
        fig, ax = plt.subplots(n_subplot_rows, n_subplot_cols)

    # Flatten to a row-major 1D array of axes. This also covers plt.subplots(1, 1), which returns
    # a bare Axes that could not be indexed.
    axes = np.ravel(np.asarray(ax, dtype=object))

    nplots = n_subplot_rows * n_subplot_cols
    for iunit in range(nplots):
        this_ax = axes[iunit]
        if iunit >= nunits:
            this_ax.remove()  # Remove axes that aren't used
            continue
        this_ax.plot(targets, mean_fr[iunit, :], 'b-', label='data')
        fit = analysis.curve_fitting_func(targets, fit_params[iunit, 0], fit_params[iunit, 1], fit_params[iunit, 2])
        this_ax.plot(targets, fit, 'b--', label='fit')
        this_ax.set_title('Unit ' + str(iunit))

    if not axinput:
        fig.tight_layout()
def plot_boxplots(data, plt_xaxis, trendline=True, facecolor='gray', linecolor='k', box_width=0.5, label_xticks=True, ax=None):
    '''
    This function creates a boxplot for each column of input data. If the input data has NaNs,
    they are ignored.

    Args:
        data (ncol list or (m, ncol) array): Data to plot. A different boxplot is created for
            each entry of the list (or each column of the array). A 1D array or list of numbers
            is treated as a single column.
        plt_xaxis (ncol): X-axis locations or labels to plot the boxplot of each column. Entries
            that cannot be converted with ``int()`` (e.g. strings) are treated as categorical
            labels and their boxes are placed at positions 0, 1, 2, ...
        trendline (bool): If a line should be used to connect the medians of the boxplots. The
            line is drawn in ``facecolor``.
        facecolor (color): Color of the box faces, outlier markers, and trendline. Can be any
            input that pyplot interprets as a color.
        linecolor (color): Color of the box outlines, whiskers, caps, and median lines.
        box_width (float): Width of each box in x-axis units. Default 0.5
        label_xticks(bool): If the values of 'plt_xaxis' should be used to label the xticks. If
            multiple boxplots are plotted on the same figure this should be set to False.
        ax (axes handle): Axes to plot

    Examples:

        Using a rectangular array and numeric x-axis points.

        .. code-block:: python

            data = np.random.normal(0, 2, size=(20, 5))
            xaxis_pts = np.array([2,3,4,4.75,5.5])
            fig, ax = plt.subplots(1,1)
            plot_boxplots(data, xaxis_pts, ax=ax)

        .. image:: _images/boxplot_example.png

        Using a list of nonrectangular arrays with categorical x-axis points.

        .. code-block:: python

            data = [np.random.normal(0, 2, size=(10)), np.random.normal(0, 1, size=(20))]
            xaxis_pts = ['foo', 'bar']
            fig, ax = plt.subplots(1,1)
            plot_boxplots(data, xaxis_pts, ax=ax)

        .. image:: _images/boxplot_example_nonrectangular.png
    '''
    if ax is None:
        ax = plt.gca()

    # If data is 2D, turn the columns into lists
    if hasattr(data, 'ndim') and data.ndim == 2:
        data = [data[:,i] for i in range(data.shape[1])]

    # If data is a single column of numbers, wrap it in a list. Checking the dimensionality of
    # the first entry (instead of calling int() on it) also works when that value is NaN.
    if len(data) > 0 and np.ndim(data[0]) == 0:
        data = [data]

    # Float arrays so NaN masking also works for columns given as plain lists
    data = [np.asarray(column, dtype=float) for column in data]

    if trendline:
        ax.plot(plt_xaxis, [np.nanmedian(column) for column in data], color=facecolor)

    for featidx, ifeat in enumerate(plt_xaxis):
        column = data[featidx]
        try:
            int(ifeat)
        except (TypeError, ValueError):
            # Categorical label: place the box at its index instead
            ifeat = featidx
        ax.boxplot(column[~np.isnan(column)], positions=np.array([ifeat]), patch_artist=True, widths=box_width,
                   boxprops=dict(facecolor=facecolor, color=linecolor), capprops=dict(color=linecolor),
                   whiskerprops=dict(color=linecolor), flierprops=dict(color=facecolor, markeredgecolor=facecolor),
                   medianprops=dict(color=linecolor))

    if label_xticks:
        ax.set_xticklabels(plt_xaxis)
def advance_plot_color(ax, n):
    '''
    Skip ahead in an axis' color cycle so the next plotted line uses a later color.

    Args:
        ax (pyplot.Axes): axis whose color cycle should be advanced
        n (int): number of colors to skip. Values <= 0 do nothing.

    Examples:

        Skipping the first color of the cycle before plotting.

        .. code-block:: python

            plt.subplots()
            aopy.visualization.advance_plot_color(plt.gca(), 1)
            plt.plot(np.arange(10), np.arange(10))

        .. image:: _images/advance_plot_color.png
    '''
    if n <= 0:
        return

    # Newer Matplotlib exposes a get_next_color() helper on the private line factory;
    # older versions only expose the raw prop_cycler iterator.
    line_factory = getattr(ax, "_get_lines", None)
    step = getattr(line_factory, "get_next_color", None)
    if not callable(step):
        step = ax._get_lines.prop_cycler.__next__

    for _ in range(n):
        step()
def reset_plot_color(ax):
    '''
    Restore an axis' color cycle to the default so the next line starts from the first color.

    Args:
        ax (pyplot.Axes): axis whose color cycle should be reset

    Examples:

        Resetting the color cycle between two calls to `plt.plot()` so both lines share a color.

        .. code-block:: python

            plt.subplots()
            plt.plot(np.arange(10), np.ones(10))
            aopy.visualization.reset_plot_color(plt.gca())
            plt.plot(np.arange(10), 1 + np.ones(10))

        .. image:: _images/reset_plot_color.png
    '''
    # Passing None (positionally) resets the property cycle to the default
    default_cycle = None
    ax.set_prop_cycle(default_cycle)
def plot_scalebar(ax, size, label, color='black', fontsize=12, vertical=False, bbox_to_anchor=(0.1, 0.1), **kwargs):
    '''
    Add a scalebar to a plot with the given size and label. The scalebar can be vertical or
    horizontal. The left edge (bottom edge if vertical) of the scalebar will be located at the
    given bbox_to_anchor position in Axis units (0 to 1).

    Args:
        ax (pyplot.Axes): axis to plot the scalebar on
        size (float): size of the scalebar in units of the plot
        label (str): label for the scalebar, e.g. '1 s' or '10 um'
        color (str): color of the scalebar. Can be any input that pyplot interprets as a color.
        fontsize (int): size of the font for the label
        vertical (bool): If True, the scalebar will be vertical. Default is horizontal.
        bbox_to_anchor (tuple): (x, y) position of the scalebar in the plot in Axis units.
            Default is (0.1, 0.1).
        **kwargs: additional keyword arguments to pass to AnchoredSizeBar. ``pad`` and
            ``borderpad`` default to 0 and ``sep`` defaults to 4 here.

    Returns:
        AnchoredSizeBar: the scalebar artist that was added to ``ax``

    Examples:

        Adding a horizontal scalebar labeled '1 s' and a vertical scalebar labeled '0.1 V',
        plus a combined x/y scalebar drawn with :func:`plot_xy_scalebar`.

        .. code-block:: python

            plt.subplots()
            plt.plot(np.arange(10), np.arange(10)/10)
            aopy.visualization.plot_scalebar(plt.gca(), 1.5, '1 s', color='orange')
            aopy.visualization.plot_scalebar(plt.gca(), 0.15, '0.1 V', vertical=True, color='green')
            aopy.visualization.plot_xy_scalebar(plt.gca(), 1.5, '1 s', 0.15, '0.1 V', bbox_to_anchor=(0.8, 0.1))

        .. image:: _images/scalebar_example.png
    '''
    if vertical:
        xsize, ysize = 0, size
        label_top = True
        loc = 'lower center'
    else:
        xsize, ysize = size, 0
        label_top = False
        loc = 'upper left'

    # Defaults that differ from AnchoredSizeBar's own. Popped up front so they are never
    # passed twice through **kwargs.
    pad = kwargs.pop('pad', 0)
    borderpad = kwargs.pop('borderpad', 0)
    sep = kwargs.pop('sep', 4)

    # Draw the scalebar: size is in data units, position is in axes units
    scalebar = AnchoredSizeBar(
        ax.transData,
        xsize,
        label,
        loc=loc,
        bbox_to_anchor=bbox_to_anchor,
        bbox_transform=ax.transAxes,
        pad=pad,
        borderpad=borderpad,
        sep=sep,
        color=color,
        frameon=False,
        label_top=label_top,
        size_vertical=ysize,
        fontproperties=fm.FontProperties(size=fontsize),
        **kwargs
    )
    ax.add_artist(scalebar)
    return scalebar
def plot_xy_scalebar(ax, xsize, xlabel, ysize, ylabel, color='black', fontsize=12, bbox_to_anchor=(0.1, 0.1), **kwargs):
    '''
    Shortcut to add two scalebars (one horizontal, one vertical) to a plot with the given x and
    y sizes and labels, both anchored at the same position.

    Args:
        ax (pyplot.Axes): axis to plot the scalebar on
        xsize (float): size of the x scalebar
        xlabel (str): label for the x scalebar
        ysize (float): size of the y scalebar
        ylabel (str): label for the y scalebar
        color (str): color of the scalebar. Can be any input that pyplot interprets as a color.
        fontsize (int): size of the font for the label
        bbox_to_anchor (tuple): (x, y) position of the scalebar in the plot in Axis units.
            Default is (0.1, 0.1).
        **kwargs: additional keyword arguments to pass to AnchoredSizeBar (applied to both bars)

    See also:
        :func:`~aopy.visualization.plot_scalebar`
    '''
    # Immutable tuple default: the same value is shared by every call of this function
    plot_scalebar(ax, xsize, xlabel, color=color, fontsize=fontsize, bbox_to_anchor=bbox_to_anchor, **kwargs)
    plot_scalebar(ax, ysize, ylabel, color=color, fontsize=fontsize, vertical=True, bbox_to_anchor=bbox_to_anchor, **kwargs)
def profile_data_channels(data, samplerate, figuredir, **kwargs):
    """
    Runs `plot_channel_summary` and `combine_channel_figures` on all channels in a data array.
    Each channel's summary figure is saved as ``ch_{chidx}.png`` in ``figuredir`` and the
    combined grid as ``all_ch.png``.

    Args:
        data (nt, nch): numpy array of neural data
        samplerate (int): sampling rate of data
        figuredir (str): string indicating file path to desired save directory. Created if it
            does not exist.
        kwargs (**dict): keyword arguments to pass to plot_channel_summary(). ``figsize`` and
            ``dpi`` (if given) are also used to lay out the combined image.

    .. image:: _images/channel_profile_example.png
    """
    os.makedirs(figuredir, exist_ok=True)

    _, nch = data.shape
    for chidx in tqdm(range(nch)):
        chname = f'ch. {chidx+1}'
        fig = plot_channel_summary(data[:,chidx], samplerate, title=chname, **kwargs)
        fig.savefig(os.path.join(figuredir,f'ch_{chidx}.png'))
        # Close each figure once it is saved; otherwise every channel's figure stays open
        plt.close(fig)

    # Tile size in the combined image must match the size the figures were saved at
    combine_channel_figures(figuredir, nch=nch, figsize=kwargs.get('figsize', (6,5)), dpi=kwargs.get('dpi', 150))
def combine_channel_figures(figuredir, nch=256, figsize=(6,5), dpi=150):
    """
    Combines all channel figures in directory generated from plot_channel_summary into a single
    grid image saved as ``all_ch.png`` in the same directory.

    Args:
        figuredir (str): path to directory of channel profile images (``ch_{chidx}.png``)
        nch (int, optional): number of channels from data array. Determines combined image
            layout. Defaults to 256.
        figsize (tuple, optional): (width, height) the channel figures were created with.
            Default (6, 5)
        dpi (int, optional): resolution the channel figures were created with. Default 150
    """
    assert os.path.exists(figuredir), f"Directory not found: {figuredir}"
    ncol = int(np.ceil(np.sqrt(nch))) # make things as square as possible
    nrow = int(np.ceil(nch/ncol))

    # Tile size in pixels, assumed to match the saved figures (TODO: read it from the files).
    # PIL needs integer sizes and offsets, but figsize*dpi can be a float.
    imgw = int(round(figsize[0] * dpi))
    imgh = int(round(figsize[1] * dpi))
    grid = Image.new(mode='RGB', size=(ncol*imgw, nrow*imgh))

    print(f'profiling all {nch} channels...')
    for chidx in tqdm(range(nch)):
        figurefile = os.path.join(figuredir,f'ch_{chidx}.png')
        if not os.path.exists(figurefile):
            continue  # missing channels leave their tile blank
        rowidx, colidx = divmod(chidx, ncol)
        with Image.open(figurefile) as img:
            grid.paste(img, box=(colidx*imgw, rowidx*imgh))

    grid.save(os.path.join(figuredir,'all_ch.png'),'png')
def plot_channel_summary(chdata, samplerate, nperseg=None, noverlap=None, trange=None, title=None, figsize=(6, 5), dpi=150, frange=(0, 80), cmap_lim=(0, 40)):
    """
    Plot time domain trace, spectrogram and normalized (z-scored) spectrogram. Computes spectrogram.
    ::

        ---------------
        | time series |
        |-------------|
        | spectrogram |
        |-------------|
        | norm sgram  |
        ---------------

    Args:
        chdata (nt,): 1D neural recording data from a given channel (lfp, ecog, broadband)
        samplerate (int): data sampling rate
        nperseg (int, optional): length of each spectrogram window (in samples). Default is
            2 seconds of data, or the whole signal if it is shorter than that.
        noverlap (int, optional): number of samples shared between neighboring spectrogram
            windows (in samples). Default is 1.5 seconds of data; if that is not shorter than
            the window, 3/4 of the window is used instead.
        trange (tuple, optional): (min, max) time range to display. Default show the entire time series
        title (str, optional): print a title above the timeseries data. Default None
        figsize (tuple, optional): (width, height) to pass to pyplot. Default (6, 5)
        dpi (int, optional): resolution to pass to pyplot. Default 150
        frange (tuple, optional): range of frequencies to display in spectrogram. Default (0, 80)
        cmap_lim (tuple, optional): clim to display in the spectrogram. Default (0, 40)

    Returns:
        fig (Figure): Figure object
    """
    assert len(chdata.shape) < 2, "Input data array must be 1d"
    nt = len(chdata)
    time = np.arange(nt)/samplerate
    if trange is None:
        trange = (time[0], time[-1])
    if nperseg is None:
        # 2 s windows, but never longer than the signal itself
        nperseg = min(int(2*samplerate), nt)
    if noverlap is None:
        noverlap = int(1.5*samplerate)
        # scipy clips nperseg to the signal length and then requires noverlap < nperseg.
        # For short signals or small windows, fall back to 75% overlap instead of failing.
        effective_nperseg = min(nperseg, nt)
        if noverlap >= effective_nperseg:
            noverlap = (3*effective_nperseg)//4

    f_sg, t_sg, sgram = signal.spectrogram(
        chdata,
        fs=samplerate,
        nperseg=nperseg,
        noverlap=noverlap,
        detrend='linear'
    )
    log_sgram = np.log10(sgram)

    fig, ax = plt.subplots(3,1,figsize=figsize,dpi=dpi,constrained_layout=True,sharex=True)

    # Time series
    ax[0].plot(time, chdata)

    # Spectrogram in dB
    sg_pcm = ax[1].pcolormesh(t_sg,f_sg,10*log_sgram,vmin=cmap_lim[0],vmax=cmap_lim[1],shading='auto')
    ax[1].set_ylim(*frange)
    sg_cb = plt.colorbar(sg_pcm,ax=ax[1])
    sg_cb.ax.set_ylabel('dB$\\mu$')

    # Spectrogram z-scored over time within each frequency
    sgn_pcm = ax[2].pcolormesh(t_sg,f_sg,zscore(log_sgram,axis=-1),vmin=-3,vmax=3,shading='auto',cmap='bwr')
    ax[2].set_ylim(*frange)
    sgn_cb = plt.colorbar(sgn_pcm,ax=ax[2])
    sgn_cb.ax.set_ylabel('z-scored dB$\\mu$')

    ax[0].set_xlim(*trange)
    ax[0].set_ylabel('amp. ($\\mu V$)')
    ax[1].set_ylabel('freq. (Hz)')
    ax[2].set_ylabel('freq. (Hz)')
    ax[2].set_xlabel('time (s)')
    ax[0].set_title(title)

    return fig
def plot_corr_over_elec_distance(elec_data, elec_pos, ax=None, **kwargs):
    '''
    Plot the correlation between electrodes as a function of the distance between them.

    Args:
        elec_data (nt, nelec): electrode data, one column per row of elec_pos
        elec_pos (nelec, 2): x, y position of each electrode
        ax (pyplot.Axes, optional): axis on which to plot. Default current axis.
        kwargs (dict, optional): ``label`` is used as the line label; all other arguments are
            supplied to :func:`aopy.analysis.calc_corr_over_elec_distance`

    Example:
        Using the multichannel test data generator in utils, we get a phase-shifted sine wave
        in each channel. Assigning each channel i to an electrode with position (i, 0), the
        correlation across distances looks like this:

        .. code-block:: python

            duration = 0.5
            samplerate = 1000
            n_channels = 30
            frequency = 100
            amplitude = 0.5
            acq_data = aopy.utils.generate_multichannel_test_signal(duration, samplerate, n_channels, frequency, amplitude)
            elec_pos = np.stack((range(n_channels), np.zeros((n_channels,))), axis=-1)

            plt.figure()
            plot_corr_over_elec_distance(acq_data, elec_pos)

        .. image:: _images/corr_over_dist.png

    Updated:
        2024-03-13 (LRS): Changed input from acq_data and acq_ch to elec_data.
        2024-07-01 (LRS): Fixed default x-axis label units to mm.
    '''
    target_ax = plt.gca() if ax is None else ax

    # 'label' is for the plotted line, not for the correlation calculation
    line_label = kwargs.pop('label', None)
    dist, corr = analysis.calc_corr_over_elec_distance(elec_data, elec_pos, **kwargs)

    target_ax.plot(dist, corr, label=line_label)
    target_ax.set_xlabel('binned electrode distance (mm)')
    target_ax.set_ylabel('correlation')
    target_ax.set_ylim(0,1)
def plot_corr_across_entries(preproc_dir, subjects, ids, dates, band=(70,200), taper_len=0.1, num_seconds=60, cmap='viridis', ax=None, remove_bad_ch=True, **bad_ch_kwargs):
    '''
    Plot the correlation vs electrode distance for each entry in the given list of subjects,
    ids, and dates. Entries whose preprocessed data cannot be loaded are skipped with a message.

    Args:
        preproc_dir (str): path to the preprocessed data directory
        subjects (list): list of subject names
        ids (list): list of te_ids
        dates (list): list of dates. Also used as the legend label of each entry.
        band (tuple, optional): frequency band to filter the data. Default (70, 200)
        taper_len (float, optional): length of taper to use in the filter. Default 0.1
        num_seconds (int, optional): number of seconds to use in the correlation calculation. Default 60
        cmap (str, optional): colormap to use for plotting. Default 'viridis'
        ax (pyplot.Axes, optional): axis on which to plot. Default current axis
        remove_bad_ch (bool, optional): whether to remove bad channels from the data. Default True
        bad_ch_kwargs (dict, optional): keyword arguments to pass to
            :func:`aopy.preproc.quality.detect_bad_ch_outliers`

    Example:
        Plotting the correlation vs electrode distance for a few entries in the preprocessed
        data directory.

        .. image:: _images/corr_over_entries.png
    '''
    assert len(subjects) == len(ids) == len(dates), "Subjects, ids, and dates must be equal length"
    if ax is None:
        ax = plt.gca()
    # One color per entry, sampled from the requested colormap
    ax.set_prop_cycle('color', sns.color_palette(cmap, len(subjects)))

    for subject, te_id, date in zip(subjects, ids, dates):
        try:
            lfp_data, lfp_metadata = aodata.load_preproc_lfp_data(preproc_dir, subject, te_id, date)
            _, exp_metadata = aodata.load_preproc_exp_data(preproc_dir, subject, te_id, date)
        except Exception:
            # Best effort: skip entries that cannot be loaded (but let KeyboardInterrupt through)
            print(f"Could not find data for entry {te_id} ({subject} on {date})")
            continue

        try:
            elec_pos, acq_ch, _ = aodata.load_chmap(exp_metadata['drmap_drive_type'])
        except Exception:
            # Fall back to the ECoG244 map when the drive type is missing or cannot be loaded
            elec_pos, acq_ch, _ = aodata.load_chmap('ECoG244')

        samplerate = lfp_metadata['samplerate']
        # Slice bounds must be integers even if the samplerate is stored as a float.
        # acq_ch appears to be 1-indexed, hence the -1.
        nsamples = int(num_seconds*samplerate)
        short_data = lfp_data[:nsamples,acq_ch-1]
        filt_data = precondition.mt_bandpass_filter(short_data, band, taper_len, samplerate, verbose=False)

        if remove_bad_ch:
            bad_ch = preproc.quality.detect_bad_ch_outliers(filt_data, **bad_ch_kwargs)
            filt_data = filt_data[:,~bad_ch]
            elec_pos = elec_pos[~bad_ch]

        plot_corr_over_elec_distance(filt_data, elec_pos, label=date, ax=ax)

    leg = ax.legend(bbox_to_anchor = (1,1))
    # Newer matplotlib names this attribute legend_handles; older versions use legendHandles
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = getattr(leg, 'legendHandles', [])
    for obj in handles:
        obj.set_linewidth(4.0)
def plot_tfr(values, times, freqs, cmap='plasma', logscale=False, ax=None, **kwargs):
    '''
    Plot a time-frequency representation of a signal.

    Args:
        values ((nfreq, nt) array): time-frequency values to plot. Rows correspond to ``freqs``
            and columns to ``times``, which is the layout pyplot.pcolormesh expects for the
            color array when given X=times and Y=freqs.
        times ((nt,) array): time of each column of ``values`` (s)
        freqs ((nfreq,) array): frequency of each row of ``values`` (Hz)
        cmap (str, optional): colormap to use for plotting
        logscale (bool, optional): apply a log scale (log10) to the color axis. Default False.
        ax (pyplot.Axes, optional): axes on which to plot. Default current axis.
        kwargs (dict, optional): other keyword arguments to pass to pyplot.pcolormesh

    Returns:
        pyplot.Image: image object returned from pyplot.pcolormesh. Useful for adding colorbars, etc.

    Examples:

        .. code-block:: python

            fig, ax = plt.subplots(3,1,figsize=(4,6))

            samplerate = 1000
            data_200_hz = aopy.utils.generate_multichannel_test_signal(2, samplerate, 8, 200, 2)
            nt = data_200_hz.shape[0]
            data_200_hz[:int(nt/3),:] /= 3
            data_200_hz[int(2*nt/3):,:] *= 2

            data_50_hz = aopy.utils.generate_multichannel_test_signal(2, samplerate, 8, 50, 2)
            data_50_hz[:int(nt/2),:] /= 2

            data = data_50_hz + data_200_hz
            aopy.visualization.plot_timeseries(data, samplerate, ax=ax[0])
            aopy.visualization.plot_freq_domain_amplitude(data, samplerate, ax=ax[1])

            freqs = np.linspace(1,250,100)
            coef = aopy.analysis.calc_cwt_tfr(data, freqs, samplerate, fb=10, f0_norm=1, verbose=True)
            t = np.arange(nt)/samplerate

            pcm = aopy.visualization.plot_tfr(abs(coef[:,:,0]), t, freqs, 'plasma', ax=ax[2])
            fig.colorbar(pcm, label='Power', orientation = 'horizontal', ax=ax[2])

        .. image:: _images/tfr_cwt_50_200.png

    See Also:
        :func:`~aopy.analysis.calc_cwt_tfr`
    '''
    if ax is None:
        ax = plt.gca()

    color_values = np.log10(values) if logscale else values
    pcm = ax.pcolormesh(times, freqs, color_values, cmap=cmap, **kwargs)
    # Draw cell edges in the face color (presumably to avoid visible seams between cells,
    # e.g. in vector output)
    pcm.set_edgecolor('face')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    return pcm
def plot_tf_map_grid(freqs, time, tf_data, bands, elec_pos, clim=None, interp_grid=None, cmap='viridis', grid_size=(4,4), colorbar=True, **kwargs):
    '''
    Plot a grid of different frequency bands and time points for a given time-frequency map
    across spatial locations. Each subplot shows the mean over the frequencies strictly inside
    one band, at one time point, as a spatial map. Rows are bands and columns are time points.

    Args:
        freqs (nfreq,): frequency values
        time (nt,): time values
        tf_data (nfreq, nt, nch): time-frequency data across spatial channels
        bands (list): list of tuples of frequency bands to plot
        elec_pos (nch, 2): x, y position of each electrode
        clim (tuple, optional): color limits for the plot, e.g. (0,1) for tfcoh maps. Default None
        interp_grid (tuple, optional): (x, y) grid to interpolate the data onto. Default None
        cmap (str, optional): colormap to use for plotting. Default 'viridis'
        grid_size (tuple, optional): (width, height) in inches of each subplot grid. Default (4,4)
        colorbar (bool, optional): whether to add a colorbar to each subplot. Default True
        kwargs (dict, optional): other keyword arguments to pass to calc_data_map (only used
            when interp_grid is given)

    Returns:
        pyplot.Axes or array of pyplot.Axes: axes for each subplot, shaped as
        ``plt.subplots(len(bands), len(time))`` would return them (a single Axes for a 1x1 grid,
        a 1D array for a single band or time point, otherwise (nbands, nt)).

    Examples:

        Random power across space with increased power at time 1 and decreased power in high frequencies.

        .. code-block:: python

            nfreq = 100
            nt = 3
            nch = 100
            freqs = np.linspace(1,250,nfreq)
            time = np.linspace(0, 1, nt)
            tf_data = np.random.rand(nfreq,nt,nch)
            tf_data[:,1,:] *= 2 # increase power at time 1
            tf_data[freqs > 10, :, :] *= 0.5 # decrease power in high frequencies
            bands = [(1, 10), (10, 250)]
            x, y = np.meshgrid(np.arange(10), np.arange(10))
            elec_pos = np.zeros((100,2))
            elec_pos[:,0] = x.reshape(-1)
            elec_pos[:,1] = y.reshape(-1)
            plot_tf_map_grid(freqs, time, tf_data, bands, elec_pos, clim=(0,1), interp_grid=None, cmap='viridis')

        .. image:: _images/tf_map_grid.png
    '''
    nbands = len(bands)
    nt = len(time)
    # squeeze=False always gives a 2D (nbands, nt) array, even for a 1x1 grid
    _, axes = plt.subplots(nbands, nt, figsize=(grid_size[0]*nt, grid_size[1]*nbands),
                           layout='constrained', squeeze=False)
    x_pos = elec_pos[:,0]
    y_pos = elec_pos[:,1]

    for band_idx, band in enumerate(bands):
        in_band = (freqs > band[0]) & (freqs < band[1])
        for t_idx, t in enumerate(time):
            this_ax = axes[band_idx, t_idx]
            band_mean = np.squeeze(np.mean(tf_data[in_band,t_idx,:], axis=0))

            if interp_grid is None:
                data_map = get_data_map(band_mean, x_pos, y_pos)
            else:
                data_map = calc_data_map(band_mean, x_pos, y_pos, interp_grid, **kwargs)

            im = plot_spatial_map(data_map, x_pos, y_pos, cmap=cmap, ax=this_ax)
            if clim is not None:
                im.set_clim(clim)
            if colorbar:
                plt.colorbar(im, ax=this_ax, shrink=0.7)
            this_ax.set_title(f't={t:.2f}, {band[0]}-{band[1]} Hz')
            this_ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

    # Return the same shape plt.subplots would have returned with its default squeeze=True
    if axes.size == 1:
        return axes[0, 0]
    return axes.squeeze()
def get_color_gradient_RGB(npts, end_color, start_color=(1,1,1)):
    '''
    This function outputs an ordered list of RGB colors that are linearly spaced between the
    start color (white by default) and the input end color.

    See also sns.color_palette for a gradient of RGB values within a Seaborn color palette.

    Examples:

        .. code-block:: python

            npts = 200
            x = np.linspace(0, 2*np.pi, npts)
            y = np.sin(x)
            fig, ax = plt.subplots()
            ax.scatter(x, y, c=get_color_gradient_RGB(npts, 'g', [1,0,0]))

        .. image:: _images/color_gradient_example.png

    Args:
        npts (int): How many different colors are part of the gradient
        end_color (str or list): Color that ends the gradient. Can be any matplotlib color or
            specific RGB values.
        start_color (str or list): Color that starts the gradient. Can be any matplotlib color
            or specific RGB values. Defaults to white.

    Returns:
        (npts, 3): An array with linearly spaced colors from the start (row 0) to the end
        (last row)
    '''
    rgb_end = matplotlib.colors.to_rgb(end_color)
    rgb_start = matplotlib.colors.to_rgb(start_color)

    # Interpolate each channel from end to start and flip it, so row 0 is start_color and the
    # last row is end_color.
    ct = np.zeros((npts, 3))
    for ich in range(3):
        ct[:,ich] = np.flip(np.linspace(rgb_end[ich], rgb_start[ich], npts))
    return ct
def plot_laser_sensor_alignment(sensor_volts, samplerate, stim_times, ax=None):
    '''
    Plot laser sensor data aligned to the stimulus times. Useful to debug laser timing issues
    to make sure the laser is actually on when you think it is. The window spans 0.1 s before
    to 0.1 s after each stimulus time, displayed in milliseconds, with one row per trial.

    Args:
        sensor_volts (array): laser sensor data, passed directly to
            :func:`aopy.analysis.calc_erp`. Only the first channel of the aligned result is
            plotted (presumably (nt,) or (nt, nch) continuous samples; TODO confirm).
        samplerate (float): sampling rate of the sensor data
        stim_times ((nstim,) array): times at which the laser was turned on (in the time units
            calc_erp expects; presumably seconds)
        ax (pyplot.Axes, optional): axes on which to plot. Default current axis.

    Returns:
        the image object returned by :func:`plot_image_by_time`. Useful for adding colorbars, etc.

    Examples:

        .. image:: _images/laser_sensor_alignment.png
    '''
    if ax is None:
        ax = plt.gca()
    # plot_image_by_time is not given an axis here, so make the requested axes current before
    # calling it (assumes it draws on the current axes; TODO confirm).
    plt.sca(ax)

    time_before = 0.1 # seconds
    time_after = 0.1 # seconds
    analog_erp = analysis.calc_erp(sensor_volts, stim_times, time_before, time_after, samplerate)
    t = 1000*(np.arange(analog_erp.shape[0])/samplerate - time_before) # milliseconds

    im = plot_image_by_time(t, analog_erp[:,0,:], ylabel='trials')
    ax.set_xlabel('time (ms)')
    ax.set_title('laser sensor aligned')
    return im
def plot_circular_hist(data, bins=16, density=False, offset=0, proportional_area=False, gaps=False, normalize=False, ax=None, **kwargs):
    '''
    Plot a circular histogram of angles on a given ax.

    Adapted from: https://stackoverflow.com/questions/22562364/circular-polar-histogram-in-python.

    Args:
        data (arr): angles to plot, in radians. Any array-like (list, tuple, ndarray).
        bins (int, optional): defines the number of equal-width bins in the range. Default is 16.
        density (bool, optional): whether to return the probability density function at each
            bin, instead of the number of samples (passed to np.histogram). Default is False.
        offset (float, optional): the offset for the location of the 0 direction, in radians.
            Default is 0.
        proportional_area (bool, optional): If True, plots bars proportional to area. If False,
            plots bars proportional to radius. Default is False.
        gaps (bool, optional): whether to allow gaps between bins. If True, the bins will only
            span the values of the data. If False, the bins are forced to partition the entire
            [-pi, pi] range. Default is False.
        normalize (bool, optional): whether to normalize the bin values such that the max value
            is 1. An all-zero histogram is left as zeros. Default is False.
        ax (pyplot.Axes, optional): axes on which to plot. Should be an axis instance created
            with subplot_kw=dict(projection='polar'). Default current axis.
        kwargs (dict, optional): other keyword arguments to pass to ax.bar

    Returns:
        n (arr or list of arr): the number of values in each bin
        bins (arr): the edges of the bins
        patches (`.BarContainer` or list of a single `.Polygon`): container of individual
            artists used to create the histogram or list of such containers if there are
            multiple input datasets

    Examples:

        .. image:: _images/circular_histograms.png
    '''
    if ax is None:
        ax = plt.gca()

    # Accept lists/tuples as well as arrays
    data = np.asarray(data, dtype=float)

    # Wrap angles to [-pi, pi)
    data = (data + np.pi) % (2*np.pi) - np.pi

    # Force bins to partition entire circle
    if not gaps:
        bins = np.linspace(-np.pi, np.pi, num=bins+1)

    # Bin data and record counts
    n, bins = np.histogram(data, bins=bins, density=density)

    # Compute width of each bin
    widths = np.diff(bins)

    if proportional_area:
        # Area assigned to each bin is its fraction of the samples (zero for empty input)
        area = n / data.size if data.size > 0 else np.zeros(n.shape)
        # Radius of a circle with that area
        radius = (area/np.pi) ** .5
        # Remove ylabels for area plots (they are mostly obstructive)
        ax.set_yticks([])
    else:
        # Otherwise plot frequency proportional to radius
        radius = n

    # If indicated, normalize the bar values so that the max is 1 (avoid dividing by zero)
    if normalize:
        peak = np.max(radius)
        radius = radius / peak if peak > 0 else np.zeros(np.shape(radius))

    # Plot data on ax
    patches = ax.bar(bins[:-1], radius, width=widths, align='edge', **kwargs)

    # Set the direction of the zero angle
    ax.set_theta_offset(offset)

    return n, bins, patches
def overlay_image_on_spatial_map(filepath, drive_type, theta=0, center=(0,0), color=None, invert=False, ax=None, **kwargs):
    '''
    Draw an image over a spatial map of electrodes so it covers the same area as the electrodes
    of the given drive. The image's grayscale intensity sets its transparency: by default dark
    pixels are opaque and light pixels transparent (reversed when ``invert`` is True). The image
    can optionally be recolored with a single color.

    Args:
        filepath (str): path to the image file
        drive_type (str): drive type to use for the spatial map. See
            :func:`aopy.data.load_chmap` for options.
        theta (int, optional): rotation of the image in degrees. The image is rotated in
            90 degree steps (theta/90 rounded up). Default is 0.
        center (2-tuple): coordinates where the drive is centered on the brain (in mm).
            Default (0,0).
        color (str, optional): color to use for the image. Default is None (keep the image's
            own RGB colors).
        invert (bool, optional): whether to invert the image transparency. Default is False.
        ax (pyplot.Axes, optional): axes on which to plot. Default current axis.
        kwargs (dict, optional): other keyword arguments to pass to ax.imshow, e.g. alpha.
    '''
    if ax is None:
        ax = plt.gca()

    with Image.open(filepath) as source:
        gray = np.array(source.convert('L')) / 255
        rgba = np.zeros((*gray.shape, 4))
        if color is None:
            rgba[:, :, :3] = np.array(source.convert('RGB')) / 255
        else:
            rgba[:, :, :3] = colors.to_rgb(color)

    # Alpha channel from grayscale intensity
    rgba[:, :, 3] = gray if invert else 1 - gray
    rgba = np.rot90(rgba, np.ceil(theta / 90), axes=(1, 0))

    # Place the image over the electrodes, padded by half an electrode spacing on each side
    positions, _, _ = aodata.load_chmap(drive_type, theta=theta, center=center)
    xs = positions[:, 0]
    ys = positions[:, 1]
    xmin, xmax = np.min(xs), np.max(xs)
    ymin, ymax = np.min(ys), np.max(ys)
    half_dx = (xmax - xmin) / (len(np.unique(xs)) - 1) / 2
    half_dy = (ymax - ymin) / (len(np.unique(ys)) - 1) / 2
    extent = [xmin - half_dx, xmax + half_dx, ymin - half_dy, ymax + half_dy]

    ax.imshow(rgba, origin='upper', extent=extent, **kwargs)
def overlay_sulci_on_spatial_map(subject, chamber, drive_type, theta=0, center=(0,0), alpha=0.5, **kwargs):
    '''
    Overlay a precomputed image of chamber sulci on a spatial map of electrodes. Images are
    stored in the aopy config directory and named
    ``{subject}_{chamber}_{drive_type}_sulci.png`` (all lowercase).

    Currently available images are:
        - Affi LM1 ECoG244
        - Beignet LM1 ECoG244

    Args:
        subject (str): subject name
        chamber (str): chamber type
        drive_type (str): drive type of the spatial map. See :func:`~aopy.data.load_chmap` for options.
        theta (int, optional): rotation of the image in degrees. Default is 0.
        center (2-tuple): coordinates where the drive is centered on the brain (in mm). Default (0,0).
        alpha (float, optional): transparency of the image. Default is 0.5.
        kwargs (dict, optional): other keyword arguments to pass to
            :func:`overlay_image_on_spatial_map`, e.g. color or ax.

    Examples:

        .. code-block:: python

            elec_pos, acq_ch, elecs = aodata.load_chmap('ECoG244')
            plot_spatial_map(np.arange(16*16).reshape((16,16)), elec_pos[:,0], elec_pos[:,1])
            overlay_sulci_on_spatial_map('beignet', 'LM1', 'ECoG244')

        .. image:: _images/overlay_sulci_beignet.png

        .. code-block:: python

            plot_spatial_map(np.arange(16*16).reshape((16,16)), elec_pos[:,0], elec_pos[:,1])
            overlay_sulci_on_spatial_map('affi', 'LM1', 'ECoG244', theta=90)

        .. image:: _images/overlay_sulci_affi.png
    '''
    image_name = '_'.join(part.lower() for part in (subject, chamber, drive_type)) + '_sulci.png'
    resource = files('aopy').joinpath('config').joinpath(image_name)

    # as_file yields a real filesystem path even if the package is not installed as plain files
    with as_file(resource) as image_path:
        overlay_image_on_spatial_map(image_path, drive_type, theta=theta, center=center, alpha=alpha, **kwargs)
def plot_annotated_spatial_drive_map_stim(data, stim_site, subject, chamber, theta, elec_data=True, interp=True, grid_size=(16,16),
                                          cmap='viridis', clim=None, colorbar=True, annotation_style='marker', fontsize=8, marker='D',
                                          markersize=1.0, color='w', recording_drive_type='ECoG244', stim_drive_type='Opto32',
                                          ax=None, **kwargs):
    '''
    Stimulation-specific version of :func:`plot_spatial_drive_map`. It annotates the
    stimulation site, removes tick marks and spines, and overlays the chamber sulci.

    Args:
        data (nch): data to plot
        stim_site (int): stimulation site to annotate on the map
        subject (str): subject name
        chamber (str): chamber type (e.g. 'LM1')
        theta (int): rotation of the chamber in degrees
        elec_data (bool, optional): whether to treat data as per electrode (True) or per
            acquisition channel (False). Default True.
        interp (bool, optional): whether to interpolate the data onto a grid. See
            :func:`~aopy.visualization.calc_data_map` for options. Defaults to True.
        grid_size (tuple, optional): size of the grid to interpolate. Default (16, 16)
        cmap (str, optional): colormap to use for plotting. Default 'viridis'.
        clim (tuple, optional): 2-tuple of color limits (min, max) for the plot. Default None.
        colorbar (bool, optional): whether to add a colorbar to the plot. Default True
        annotation_style (str, optional): style of annotation to use for stimulation site
            ['text', 'marker']. Default 'marker'.
        fontsize (int, optional): the fontsize to make the text or marker. Defaults to 8.
        marker (str, optional): marker style for annotations if annotation_style is 'marker'.
            Options are the same as pyplot.markers.MarkerStyle; e.g. 'o', 's', etc. Default 'D'.
        markersize (float, optional): size of the marker in data units for annotations if
            annotation_style is 'marker'. Default 1.0.
        color (str, optional): color for annotations and the sulci overlay. Default 'w'
        recording_drive_type (str, optional): drive type of the recording. Default 'ECoG244'.
            See :func:`aopy.data.load_chmap` for options.
        stim_drive_type (str, optional): drive type of the stimulation. Default 'Opto32'.
            See :func:`aopy.data.load_chmap` for options.
        ax (pyplot.Axes, optional): axes on which to plot. Default current axis.
        kwargs (dict, optional): other keyword arguments to pass to :func:`plot_spatial_drive_map`

    Returns:
        tuple: tuple containing:
            - im (pyplot.Image): image object returned from pyplot.imshow. Useful for adding colorbars, etc.
            - pcm (pyplot.Colorbar): colorbar object if colorbar is True, otherwise None

    Examples:

        .. code-block:: python

            data = np.random.normal(0, 1, (240,))
            stim_site = 7
            plot_annotated_spatial_drive_map_stim(data, stim_site, 'beignet', 'lm1', 0, interp_method='cubic')

        .. image:: _images/annotated_spatial_drive_map_stim.png
    '''
    target_ax = plt.gca() if ax is None else ax

    # Base map of the recording drive data
    im = plot_spatial_drive_map(data, elec_data=elec_data, interp=interp, grid_size=grid_size,
                                drive_type=recording_drive_type, cmap=cmap, theta=theta, ax=target_ax, **kwargs)
    if clim is not None:
        im.set_clim(*clim)

    sns.despine(left=True, bottom=True, ax=target_ax)
    pcm = plt.colorbar(im, ax=target_ax) if colorbar else None
    target_ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[], xlabel='', ylabel='')

    # Mark the stimulation site using the stimulation drive's channel layout
    annotation_opts = dict(fontsize=fontsize, color=color, annotation_style=annotation_style,
                           marker=marker, markersize=markersize)
    annotate_spatial_map_channels(acq_ch=[stim_site], drive_type=stim_drive_type, theta=theta,
                                  ax=target_ax, **annotation_opts)
    overlay_sulci_on_spatial_map(subject, chamber, recording_drive_type, theta=theta, color=color, ax=target_ax)

    return im, pcm
def plot_annotated_stim_drive_data(data, subject, chamber, theta, interp=False, stim_drive_type='Opto32', recording_drive_type='ECoG244',
                                   cmap='Blues', colorbar=True, color='k', nan_color='white', ax=None, **kwargs):
    '''
    Plot a spatial map of data for each stimulation site in a drive, with the chamber sulci image
    for the recording drive overlaid for reference.

    Args:
        data (nch): data to plot, one value per stimulation site
        subject (str): subject name
        chamber (str): chamber type (e.g. 'LM1')
        theta (int): rotation of the chamber in degrees
        interp (bool, optional): whether to interpolate the data onto a grid. See
            :func:`~aopy.visualization.calc_data_map` for options. Defaults to False.
        stim_drive_type (str): drive type of the stimulation data. Default 'Opto32'.
            See :func:`aopy.data.load_chmap` for options.
        recording_drive_type (str): drive type of the recording used in the chamber; selects
            the sulci image. Default 'ECoG244'. See :func:`aopy.data.load_chmap` for options.
        cmap (str): colormap to use for plotting
        colorbar (bool, optional): whether to add a colorbar to the plot. Default True
        color (str): color for the sulci overlay. Default 'k'
        nan_color (str): color to use for NaN values. Default 'white'
        ax (pyplot.Axes, optional): axes on which to plot. Default current axis.
        kwargs (dict, optional): other keyword arguments to pass to :func:`plot_spatial_drive_map`

    Returns:
        tuple: tuple containing:
            - im (pyplot.Image): image object returned from pyplot.imshow. Useful for adding colorbars, etc.
            - pcm (pyplot.Colorbar): colorbar object if colorbar is True, otherwise None

    Examples:

        .. code-block:: python

            data = np.random.normal(0, 1, (32,))
            plot_annotated_stim_drive_data(data, 'beignet', 'lm1', 0)

        .. image:: _images/annotated_stim_drive_data.png
    '''
    target_ax = plt.gca() if ax is None else ax

    im = plot_spatial_drive_map(data, elec_data=True, drive_type=stim_drive_type, interp=interp, cmap=cmap,
                                theta=theta, nan_color=nan_color, ax=target_ax, **kwargs)
    pcm = plt.colorbar(im, ax=target_ax) if colorbar else None
    overlay_sulci_on_spatial_map(subject, chamber, recording_drive_type, theta=theta, color=color, ax=target_ax)
    return im, pcm
def plot_plane(plane, gain=1.0, color='grey', alpha=0.15, resolution=100, ax=None, **kwargs):
    """
    Plots a 3D plane centered at the origin.

    Args:
        plane (4-tuple or (3,3) or (4,4) array-like): Specifies how the plane is transformed:
            - If shape (3,3) or (4,4): Treated as a transformation matrix for rotating the plane z=0.
            - If shape (4,): Treated as plane equation coefficients (A, B, C, D) for Ax + By + Cz + D = 0.
            Tuples and lists are accepted as well as numpy arrays.
        gain (float, optional): Scaling factor for the plane's size. The plane spans
            [-10*gain, 10*gain] in both in-plane directions. Default is 1.0. Recommend using exp_gain from metadata.
        color (str, optional): Color of the plane. Default is 'grey'.
        alpha (float, optional): Transparency of the plane, where 1 is opaque and 0 is fully transparent. Default is 0.15.
        resolution (int, optional): Number of subdivisions for the plane. Higher values increase smoothness. Default is 100.
        ax (mpl_toolkits.mplot3d.Axes3D, optional): The Matplotlib 3D axis on which to plot the plane.
            If None, a new figure with a single 3D axis is created.
        kwargs (dict, optional): additional keyword arguments forwarded to ``ax.plot_surface``.

    Raises:
        ValueError: If 'plane' does not have a valid shape (expected (3,3), (4,4), or (4,)),
            or if a plane equation has C == 0 (a vertical plane cannot be written as z = f(x, y)).

    Note:
        - When 'plane' is a transformation matrix, only the upper-left (3,3) submatrix is used.
          Points are treated as row vectors and right-multiplied by that submatrix.
        - When 'plane' is a plane equation (A, B, C, D), the function solves for z using z = (-A * x - B * y - D) / C.

    Examples:

        .. code-block:: python

            import matplotlib.pyplot as plt
            from mpl_toolkits.mplot3d import Axes3D
            import numpy as np

            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')

            # Example using a transformation matrix (identity)
            plane = np.eye(3)
            plot_plane(plane, gain=1.0, color='blue', alpha=0.3, ax=ax)

            # Example using a plane equation Ax + By + Cz + D = 0
            plane_eq = np.array([1, 2, -1, 5])  # x + 2y - z + 5 = 0
            plot_plane(plane_eq, gain=1.0, color='red', alpha=0.5, ax=ax)

            plt.show()

        .. image:: _images/plot_plane_example.png
    """
    # Accept tuples/lists as documented, not only numpy arrays
    plane = np.asarray(plane, dtype=float)

    # Validate before creating a figure so bad input does not leave an empty figure behind
    if plane.shape not in [(3, 3), (4, 4), (4,)]:
        raise ValueError(f"Invalid mapping shape {plane.shape}. Expected (3,3) or (4,4) "
                         "for transformation matrices, or (4,) for plane equations.")
    if plane.shape == (4,) and plane[2] == 0:
        raise ValueError("Plane equation coefficient C must be nonzero to solve for z "
                         f"(got {tuple(plane)}).")

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

    xy_range = np.linspace(-10 * gain, 10 * gain, resolution)
    x, y = np.meshgrid(xy_range, xy_range)

    if plane.shape in [(3, 3), (4, 4)]:
        # Rotate the z=0 plane: points are row vectors right-multiplied by the (3,3) block
        coords = np.column_stack((x.ravel(), y.ravel(), np.zeros(x.size)))
        rotated_coords = coords @ plane[:3, :3]
        x, y, z = rotated_coords.T.reshape(3, *x.shape)
    else:
        # Plane equation Ax + By + Cz + D = 0, solved for z (C != 0 checked above)
        A, B, C, D = plane
        z = (-A * x - B * y - D) / C

    ax.plot_surface(x, y, z, alpha=alpha, color=color, **kwargs)
def plot_sphere(location, color='gray', radius=4, resolution=20, alpha=1, bounds=None, ax=None, **kwargs):
    """
    Draw a sphere as a 3D surface on a Matplotlib 3D axis.

    If no axis is given, a new figure containing a single 3D subplot is created and used.

    Args:
        location (tuple or list): (x, y, z) coordinates of the sphere's center.
        color (str, optional): Surface color. Default is 'gray'.
        radius (float, optional): Sphere radius. Default is 4.
        resolution (int, optional): Number of samples along both the azimuthal and the polar angle.
            Larger values give a smoother surface at a higher rendering cost. Default is 20.
        alpha (float, optional): Opacity of the surface, from 0 (fully transparent) to 1 (opaque). Default is 1.
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) cursor bounds,
            applied to the axis with :func:`set_bounds` after drawing. Default None (limits untouched).
        ax (mpl_toolkits.mplot3d.Axes3D, optional): 3D axis to draw on. Default None (new figure and axis).
        kwargs (dict, optional): additional keyword arguments forwarded to ``ax.plot_surface``.

    Examples:
        Plot a semi-transparent blue sphere of radius 5 centered at (0, 1, 2):

        .. code-block:: python

            import matplotlib.pyplot as plt
            from mpl_toolkits.mplot3d import Axes3D

            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            plot_sphere(location=(0, 1, 2), color='blue', radius=5, resolution=30, alpha=0.5, ax=ax)

        .. image:: _images/plot_sphere_example.png
    """
    if ax is None:
        ax = plt.figure().add_subplot(111, projection='3d')

    # Azimuth sweeps the full circle, polar angle runs pole to pole
    azimuth = np.linspace(0, 2 * np.pi, resolution)
    polar = np.linspace(0, np.pi, resolution)
    sin_polar = np.sin(polar)
    center_x, center_y, center_z = location[0], location[1], location[2]

    # Spherical -> cartesian; each grid has rows indexed by azimuth and columns by polar angle
    x = radius * np.outer(np.cos(azimuth), sin_polar) + center_x
    y = radius * np.outer(np.sin(azimuth), sin_polar) + center_y
    z = radius * np.outer(np.ones_like(azimuth), np.cos(polar)) + center_z

    ax.plot_surface(x, y, z, color=color, alpha=alpha, **kwargs)

    if bounds is not None:
        set_bounds(bounds, ax)
def color_targets_3D(target_locations, target_idx, colors=None, target_radius=1, resolution=20, alpha=1, bounds=None, ax=None, **kwargs):
    """
    Plots multiple targets as spheres in 3D space.

    Each unique (target index, location) pair is drawn once with :func:`plot_sphere`, using
    ``colors[target index]`` as its color.

    Args:
        target_locations ((ntargets, 3) array-like): (x, y, z) coordinates of the target sphere centers.
        target_idx ((ntargets,) array-like): integer index of each target, used to look up its color in ``colors``.
        colors (list of colors, optional): colors indexed by target index, so it must contain more than
            ``max(target_idx)`` entries. If None (default), every target is drawn in 'gray'.
        target_radius (float, optional): Radius of each target sphere. Default is 1.
        resolution (int, optional): Resolution of the spheres (passed to :func:`plot_sphere`). Default is 20.
        alpha (float, optional): Transparency of the spheres, where 1 is opaque. Default is 1.
        bounds (tuple, optional): 6-element tuple describing (-x, x, -y, y, -z, z) cursor bounds.
        ax (mpl_toolkits.mplot3d.Axes3D, optional): The Matplotlib 3D axis on which to plot the targets.
            If None, a new figure with a single 3D axis is created.
        kwargs (dict, optional): additional keyword arguments forwarded to :func:`plot_sphere`
            (and from there to ``ax.plot_surface``).

    Raises:
        AssertionError: if ``target_locations`` and ``target_idx`` differ in length, or if ``colors``
            has too few entries to be indexed by every target index.

    Examples:
        To visualize nine targets with different colors:

        .. code-block:: python

            import matplotlib.pyplot as plt
            from mpl_toolkits.mplot3d import Axes3D
            import seaborn as sns

            targets = np.array([
                [0., 0., 0.],
                [0., 10., 0.],
                [7.0711, 7.0711, 0.],
                [10., 0., 0.],
                [7.0711, -7.0711, 0.],
                [0., -10., 0.],
                [-7.0711, -7.0711, 0.],
                [-10., 0., 0.],
                [-7.0711, 7.0711, 0.]
            ])
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.set_zlim3d([-10, 10])
            colors = sns.color_palette(n_colors=len(targets))
            aopy.visualization.color_targets_3D(targets, target_idx=np.arange(len(targets)), target_radius=1, colors=colors, ax=ax)
            plt.show()

        .. image:: _images/plot_3D_targets.png
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

    assert len(target_locations) == len(target_idx), "Target locations must be the same length as target indices."

    # Round-trip through tolist() so nested sequences / object arrays become plain numeric arrays
    target_locations = np.array(np.array(target_locations).tolist())
    target_idx = np.array(np.array(target_idx).tolist())

    if len(target_idx) == 0:
        # Nothing to draw; still honor the requested bounds
        if bounds is not None:
            set_bounds(bounds, ax)
        return

    # Colors are looked up by index *value*, so the palette must reach max(target_idx)
    n_colors_needed = int(np.max(target_idx)) + 1
    if colors is None:  # `is None`: `== None` is ambiguous when colors is a numpy array
        colors = ['gray'] * n_colors_needed
    assert len(colors) >= n_colors_needed, "Not enough colors for unique target indices."

    # Draw each unique (index, location) pair only once
    loc_idx = np.concatenate((np.expand_dims(target_idx, 1), target_locations), axis=1)
    loc_idx = np.unique(loc_idx, axis=0)
    for row in loc_idx:
        idx = int(row[0])
        plot_sphere(row[1:], color=colors[idx], radius=target_radius, resolution=resolution,
                    alpha=alpha, ax=ax, **kwargs)

    if bounds is not None:
        set_bounds(bounds, ax)