Source code for aopy.visualization.bmi3d

# bmi3d.py
#
# visuzalition specific to BMI3D 

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from . import base
from .. import data as aodata

[docs]def plot_decoder_weight_matrix(decoder, ax=None): """ Plot the decoder weight matrix. Compatible with Decoder objects with KFDecoder and lindecoder filters. Args: decoder (riglib.bmi.Decoder): The decoder object from BMI3D. ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None. """ if ax is None: fig, ax = plt.subplots() if hasattr(decoder.filt, 'unit_to_state'): matrix = np.array(decoder.filt.unit_to_state) elif hasattr(decoder.filt, 'C'): matrix = np.array(decoder.filt.C.T) else: raise ValueError("Decoder does not have a recognizable weight matrix attribute.") max_weight = np.max(abs(matrix)) im = ax.pcolor(matrix, cmap='bwr') ax.set_xticks(np.arange(decoder.n_units)) ax.set_xticklabels([decoder.units[iu][0] for iu in range(decoder.n_units)], rotation=45) ax.set_yticks(np.arange(decoder.n_states) + .5) ax.set_yticklabels(decoder.states) ax.set(xlabel='Readout unit', ylabel='State') plt.colorbar(im, ax=ax, label='Weight') im.set_clim(-max_weight, max_weight)
[docs]def plot_readout_map(decoder, readouts, drive_type='ECoG244', cmap='YlGnBu', ax=None): """ Plot the spatial location of readouts. Args: decoder (riglib.bmi.Decoder): The decoder object from BMI3D. readouts (list): The readout channels. drive_type (str, optional): The type of drive. See :func:`~aopy.data.load_chmap` for options. Defaults to 'ECoG244'. cmap (str, optional): The colormap to use. Defaults to 'YlGnBu'. ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None. """ if ax is None: fig, ax = plt.subplots() elec_pos, acq_ch, elecs = aodata.load_chmap(drive_type=drive_type) channels = np.nan * np.zeros(np.max(acq_ch) + 1) channels[decoder.channels - 1] = np.arange(len(decoder.channels)) base.plot_spatial_drive_map(channels, interp=False, drive_type=drive_type, cmap=cmap, ax=ax, nan_color='#FF000000') base.annotate_spatial_map_channels(drive_type=drive_type, color='k', ax=ax) base.annotate_spatial_map_channels(acq_ch=readouts, drive_type=drive_type, color='w', ax=ax)
[docs]def plot_decoder_weight_vectors(decoder, x_idx, y_idx, colors, ax=None): """ Plot decoder weight vectors. Args: decoder (riglib.bmi.Decoder): The decoder object from BMI3D. x_idx (int): The index for the x state in the decoder's weight matrix y_idx (int): The index for the y state in the decoder's weight matrix colors (list): List of colors for the vectors. ax (matplotlib.axes.Axes, optional): The axes on which to plot. Defaults to None. """ if ax is None: fig, ax = plt.subplots() if hasattr(decoder.filt, 'unit_to_state'): matrix = np.array(decoder.filt.unit_to_state) elif hasattr(decoder.filt, 'C'): matrix = np.array(decoder.filt.C.T) else: raise ValueError("Decoder does not have a recognizable weight matrix attribute.") max_len = np.max(np.linalg.norm(matrix[[x_idx,y_idx],:], axis=0)) if max_len == 0: max_len = 1. nch = matrix.shape[1] ax.quiver(np.zeros(nch), np.zeros(nch), matrix[x_idx, :], matrix[y_idx, :], width=.01, color=colors, alpha=0.5, angles='xy', scale_units='xy', scale=1.) ax.set(xlabel='Px weight', ylabel='Py weight', xlim=(-max_len,max_len), ylim=(-max_len,max_len))
[docs]def plot_decoder_summary(decoder, drive_type='ECoG244', cmap='YlGnBu'): """ Plot a summary of the decoder weight matrix, readout map, and weight vectors. Example: A KF decoder with 7 states and 16 readout channels. .. image:: _images/decoder_weights.png Args: decoder (riglib.bmi.Decoder): The decoder object from BMI3D. drive_type (str, optional): The type of drive. See :func:`~aopy.data.load_chmap` for options. Defaults to 'ECoG244'. cmap (str, optional): The colormap to use. Defaults to 'YlGnBu'. """ fig, ax = plt.subplots(2, 2, figsize=(8, 8)) ax = ax.flatten() plot_decoder_weight_matrix(decoder, ax=ax[0]) ax[0].set_title('Weight matrix') plot_readout_map(decoder, decoder.channels, drive_type=drive_type, cmap=cmap, ax=ax[1]) ax[1].set_title('Readout location') colors = sns.color_palette(cmap, n_colors=len(decoder.channels)) plot_decoder_weight_vectors(decoder, 0, 2, colors, ax=ax[2]) ax[2].set_title('Pos weight vectors') plot_decoder_weight_vectors(decoder, 3, 5, colors, ax=ax[3]) ax[3].set_title('Vel weight vectors') fig.tight_layout()