# bmi3d.py
#
# Code for parsing and preparing data from BMI3D
import warnings
import os
from datetime import datetime
from importlib.metadata import version
import json
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import sympy
from .. import precondition
from .. import data as aodata
from .. import postproc
from .. import utils
from .. import analysis
from .. import visualization
from . import base
from . import laser
[docs]def decode_event(dictionary, value):
'''
Decode a integer event code into a event name and data
Args:
dictionary (dict): dictionary of (event_name, event_code) event definitions
value (int): number to decode
Returns:
tuple: 2-tuple containing (event_name, data) for the given value
'''
# Sort the dictionary in order of value
ordered_list = sorted(dictionary.items(), key=lambda x: x[1])
# Find a matching event (greatest value that is lower than the given value)
for i, event in enumerate(ordered_list[1:]):
if value < event[1]:
event_name = ordered_list[i][0]
event_data = value - ordered_list[i][1]
return event_name, event_data
# Check last value
if value == ordered_list[-1][1]:
return ordered_list[-1][0], 0
# Return none if no matching events
return None
[docs]def decode_events(dictionary, values):
'''
Decode a list of integer event code into a event names and data
Args:
dictionary (dict): dictionary of (event_name, event_code) event definitions
values (n_values): list of integer numbers to decode
Returns:
tuple: 2-tuple containing (event_names, data) for the given values
'''
tuples = [decode_event(dictionary, value) for value in values]
return list(zip(*tuples))
def _correct_hand_traj(hand_position, cursor_position):
'''
This function removes hand position data points when the cursor is simultaneously stationary in all directions.
These hand position data points are artifacts.
Args:
hand_position (nt, 3): Uncorrected hand position
cursor_position (nt, 3): Cursor position from the same experiment, used to find where the hand position is invalid.
Returns:
hand_position (nt, 3): Corrected hand position
'''
# Set hand position to np.nan if the cursor position doesn't update. This indicates an optitrack error moved the hand outside the boundary.
bad_pt_mask = np.zeros(cursor_position.shape, dtype=bool)
bad_pt_mask[1:,0] = (np.diff(cursor_position, axis=0)==0)[:,0] & (np.diff(cursor_position, axis=0)==0)[:,1] & (np.diff(cursor_position, axis=0)==0)[:,2]
bad_pt_mask[1:,1] = (np.diff(cursor_position, axis=0)==0)[:,0] & (np.diff(cursor_position, axis=0)==0)[:,1] & (np.diff(cursor_position, axis=0)==0)[:,2]
bad_pt_mask[1:,2] = (np.diff(cursor_position, axis=0)==0)[:,0] & (np.diff(cursor_position, axis=0)==0)[:,1] & (np.diff(cursor_position, axis=0)==0)[:,2]
hand_position[bad_pt_mask] = np.nan
return hand_position
def _correct_tracking_task_data(data, metadata, contains_hand=True):
'''
This function fixes the frame shift bug in reference & disturbance trajectories saved by an older version of the BMI3d tracking task.
It uses the 'current_target_validate' field (which comes from 'target.get_position()' and is the ground truth for what was displayed on the screen)
for the reference trajectory. It calculates 'user_screen' position and adds it to the task data.
Finally, it uses difference between the bounded cursor and bounded 'user_screen' for the disturbance trajectory.
Args:
data (dict): bmi3d data
metadata (dict): bmi3d metadata
contains_hand (bool, optional): whether or not the 'manual_input' data comes from optitrack, which means it needs to be cleaned
Returns:
task (nt,): array of task data with specified dtypes
'''
bmi3d_task = data['bmi3d_task']
# list of task data fields to keep
keys = list(bmi3d_task.dtype.names)
keys.remove('current_target')
keys.remove('current_disturbance')
keys.remove('current_target_validate')
# list of task data fields to create/correct
dtypes = [(key, bmi3d_task.dtype.fields[key][0]) for key in keys]
dtypes.append(('user_screen', 'f8', (3,)))
dtypes.append(('target', 'f8', (3,)))
dtypes.append(('disturbance', 'f8', (3,)))
# construct corrected task data
task = np.zeros(len(bmi3d_task), dtype=dtypes)
for key in keys:
task[key] = bmi3d_task[key]
# transform manual_input (user_raw) to user_world to user_screen
user_raw = bmi3d_task['manual_input']
if contains_hand:
if metadata['sync_protocol_version'] < 14 and isinstance(bmi3d_task, np.ndarray) and 'manual_input' in bmi3d_task.dtype.names:
clean_hand_position = _correct_hand_traj(bmi3d_task['manual_input'], bmi3d_task['cursor'])
if np.count_nonzero(~np.isnan(clean_hand_position)) > 2*clean_hand_position.ndim:
user_raw = clean_hand_position
if 'exp_gain' in metadata:
scale = metadata['scale']
exp_gain = metadata['exp_gain']
else:
scale = np.sign(metadata['scale'])
exp_gain = np.abs(metadata['scale'])
user_world = postproc.bmi3d.convert_raw_to_world_coords(user_raw, metadata['rotation'], metadata['offset'], scale) # intuitive world coords (x: right/left, y: up/down, z: forward/backward)
if 'baseline_rotation' in metadata:
baseline_rotation = metadata['baseline_rotation']
else:
baseline_rotation = 'none'
if 'exp_rotation' in metadata:
exp_rotation = metadata['exp_rotation']
else:
exp_rotation = 'none'
# fixed perturbations
if 'perturbation_rotation_x' in metadata:
x_rot = metadata['perturbation_rotation_x']
z_rot = metadata['perturbation_rotation_z']
else:
x_rot = 0
z_rot = 0
if 'pertubation_rotation' in metadata:
y_rot = metadata['pertubation_rotation']
else:
y_rot = 0
exp_mapping = postproc.bmi3d.get_world_to_screen_mapping(exp_rotation, x_rot, y_rot, z_rot, exp_gain, baseline_rotation)
user_screen = np.dot(user_world, exp_mapping) # intuitive screen coords (x: right/left, y: up/down, z: into/out of the screen)
# incremental perturbations
if 'incremental_rotation' in metadata['features']:
x_fixed = metadata['init_rotation_x']==metadata['final_rotation_x']
y_fixed = metadata['init_rotation_y']==metadata['final_rotation_y']
z_fixed = metadata['init_rotation_z']==metadata['final_rotation_z']
if y_fixed and z_fixed and not x_fixed:
start_deg = metadata['init_rotation_x']
end_deg = metadata['final_rotation_x']
delta_deg = metadata['delta_rotation_x']
elif x_fixed and z_fixed and not y_fixed:
start_deg = metadata['init_rotation_y']
end_deg = metadata['final_rotation_y']
delta_deg = metadata['delta_rotation_y']
elif x_fixed and y_fixed and not z_fixed:
start_deg = metadata['init_rotation_z']
end_deg = metadata['final_rotation_z']
delta_deg = metadata['delta_rotation_z']
trials_per_inc = metadata['trials_per_increment']
n_inc = int((end_deg-start_deg)/delta_deg+1)
rotations = ( np.tile(np.linspace(start_deg, end_deg, n_inc), (trials_per_inc,1)) ).flatten('F') # column-major order
states = data['bmi3d_state']['msg'] # bmi3d state
state_cycles = data['bmi3d_state']['time'] # bmi3d cycle number
post_reward_cycles = state_cycles[np.where(states==b'reward')[0]+1] # bmi3d cycle number of the wait state following each reward
reward_rotations = []
if len(rotations) > len(post_reward_cycles):
remainder = len(rotations) - len(post_reward_cycles)
rotations = np.delete(rotations, np.s_[-remainder:])
elif len(rotations) < len(post_reward_cycles):
remainder = len(post_reward_cycles) - len(rotations)
rotations = np.concatenate((rotations, np.ones((remainder))*rotations[-1]))
reward_rotations.extend(rotations)
# find bmi3d cycle number where mapping changes
change_idx = abs(np.diff(reward_rotations,append=np.nan))>0
change_cycles = post_reward_cycles[change_idx]
change_cycles = np.hstack((0, change_cycles, len(user_world)-1))
# transform user_world to user_screen in increments
user_screen = np.zeros(user_world.shape)
for i,idx in enumerate(change_cycles[:-1]):
if idx==0:
x_rot = metadata['init_rotation_x']
y_rot = metadata['init_rotation_y']
z_rot = metadata['init_rotation_z']
else:
if y_fixed and z_fixed and not x_fixed:
x_rot += delta_deg
elif x_fixed and z_fixed and not y_fixed:
y_rot += delta_deg
elif x_fixed and y_fixed and not z_fixed:
z_rot += delta_deg
exp_mapping = postproc.bmi3d.get_world_to_screen_mapping(exp_rotation, x_rot, y_rot, z_rot, exp_gain, baseline_rotation)
user_screen[change_cycles[i]:change_cycles[i+1]] = np.dot(user_world[change_cycles[i]:change_cycles[i+1]], exp_mapping)
task['user_screen'] = user_screen[:,[0,2,1]] # reorder to match bmi3d coords (x: right/left, y: into/out of the screen, z: up/down)
task['target'] = bmi3d_task['current_target_validate'] # this comes from target.get_position() and reflects actual reference
# clip cursor and user_screen to screen bounds to calculate actual disturbance
bounds = np.hstack((np.zeros(4), metadata['cursor_bounds'][-2:])) # disturbance is 0 along bmi3d x axis (right/left) and y axis (into/out of the screen)
cursor_bounded = np.array([np.clip(task['cursor'][:,i], bounds[i*2], bounds[i*2+1]) for i in range(3)]).T
user_bounded = np.array([np.clip(task['user_screen'][:,i], bounds[i*2], bounds[i*2+1]) for i in range(3)]).T
task['disturbance'] = cursor_bounded - user_bounded # only bmi3d z axis has non-zero values but all axes retain existing NaN values from user_screen
print('...correcting tracking task frame shift')
return task
def _correct_touch_app_data(original_task):
'''
This function replaces NaN values in cursor data saved by an older version of the tablet touch app (where the cursor would disappear whenever
there was no touch).
Before the first touch input, cursor values remain NaN but the 'plant_visible' field is updated to False to indicate that the cursor
was hidden. After the first touch input, NaN cursor values are replaced with the last valid cursor position and 'plant_visible' is set to False.
This function does not affect the 'user_screen' field, which is NaN when there is no touch input.
Args:
original_task (nt,): original array of task data with specified dtypes
Returns:
task (nt,): corrected array of task data with specified dtypes
'''
# list of task data fields to keep
keys = list(original_task.dtype.names)
dtypes = [(key, original_task.dtype.fields[key][0]) for key in keys]
# construct corrected task data
task = np.zeros(len(original_task), dtype=dtypes)
for key in keys:
task[key] = original_task[key]
# indicate that cursor was hidden when NaN
original_cursor = pd.DataFrame(original_task['cursor'])
nan_idx = np.where(np.isnan(original_cursor).any(axis=1))[0]
task['plant_visible'][nan_idx] = False
# replace NaNs in cursor with last valid value
corrected_cursor = original_cursor.ffill()
task['cursor'] = np.array(corrected_cursor)
print('...correcting touch app data')
return task
[docs]def parse_bmi3d(data_dir, files):
'''
Wrapper around version-specific bmi3d parsers
Args:
data_dir (str): where to look for the data
files (dict): dictionary of files for this experiment
Returns:
tuple: tuple containing:
| **data (dict):** bmi3d data
| **metadata (dict):** bmi3d metadata
'''
# Load bmi3d data to see which sync protocol is used
try:
events, event_metadata = aodata.load_bmi3d_hdf_table(data_dir, files['hdf'], 'sync_events')
sync_version = event_metadata['sync_protocol_version']
except:
sync_version = -1
# Pass files onto the appropriate parser
if sync_version <= 0:
data, metadata = _parse_bmi3d_v0(data_dir, files)
metadata['bmi3d_parser'] = 0
metadata['sync_protocol_version'] = sync_version
elif sync_version < 16:
data, metadata = _parse_bmi3d_v1(data_dir, files)
metadata['bmi3d_parser'] = 1
else:
print("Warning: this bmi3d sync version is untested!")
data, metadata = _parse_bmi3d_v1(data_dir, files)
metadata['bmi3d_parser'] = 1
# Keep track of the software version
metadata['bmi3d_preproc_date'] = datetime.now()
try:
metadata['bmi3d_preproc_version'] = version('aolab-aopy')
except:
metadata['bmi3d_preproc_version'] = 'unknown'
# And where the data came from
try:
metadata['bmi3d_source'] = os.path.join(data_dir, files['hdf'])
except:
metadata['bmi3d_source'] = None
# Standardize the parsed variable names and perform some error checking
if metadata['bmi3d_parser'] == 0:
prepared_data, prepared_metadata = _prepare_bmi3d_v0(data, metadata)
elif metadata['bmi3d_parser'] == 1:
prepared_data, prepared_metadata = _prepare_bmi3d_v1(data, metadata)
return prepared_data, prepared_metadata
def _parse_bmi3d_v0(data_dir, files):
'''
Simple parser for BMI3D data. Ignores external data.
Args:
data_dir (str): where to look for the data
files (dict): dictionary of files for this experiment
Returns:
tuple: tuple containing:
| **data (dict):** bmi3d data
| **metadata (dict):** bmi3d metadata
'''
bmi3d_data = {}
metadata = {}
metadata['source_dir'] = data_dir
metadata['source_files'] = files
if 'hdf' not in files:
warnings.warn("No hdf file found, cannot parse bmi3d data")
metadata['preproc_errors'] = ["No hdf file found, cannot parse bmi3d data"]
return bmi3d_data, metadata
bmi3d_hdf_filename = files['hdf']
bmi3d_hdf_full_filename = os.path.join(data_dir, bmi3d_hdf_filename)
# Load data and metadata from hdf file
bmi3d_root_metadata = aodata.load_bmi3d_root_metadata(data_dir, bmi3d_hdf_filename)
metadata.update(bmi3d_root_metadata)
if 'features' not in bmi3d_root_metadata:
metadata['features'] = []
if aodata.is_table_in_hdf('task', bmi3d_hdf_full_filename):
bmi3d_task, bmi3d_task_metadata = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'task')
metadata.update(bmi3d_task_metadata)
if len(bmi3d_task) > 0 and isinstance(bmi3d_task, np.ndarray):
bmi3d_data['bmi3d_task'] = bmi3d_task
if aodata.is_table_in_hdf('sync_events', bmi3d_hdf_full_filename):
bmi3d_events, bmi3d_event_metadata = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'sync_events') # exists in tablet data
metadata.update(bmi3d_event_metadata)
if len(bmi3d_events) > 0 and isinstance(bmi3d_events, np.ndarray):
bmi3d_data['bmi3d_events'] = bmi3d_events
if aodata.is_table_in_hdf('clda', bmi3d_hdf_full_filename):
bmi3d_clda, bmi3d_clda_meta = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'clda')
metadata.update(bmi3d_clda_meta)
bmi3d_data['bmi3d_clda'] = bmi3d_clda
if aodata.is_table_in_hdf('task_msgs', bmi3d_hdf_full_filename):
bmi3d_state, _ = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'task_msgs')
bmi3d_data['bmi3d_state'] = bmi3d_state
if aodata.is_table_in_hdf('trials', bmi3d_hdf_full_filename):
bmi3d_trials, _ = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'trials')
bmi3d_data['bmi3d_trials'] = bmi3d_trials
if aodata.is_table_in_hdf('sync_clock', bmi3d_hdf_full_filename): # exists but empty in tablet data
clock, _ = aodata.load_bmi3d_hdf_table(data_dir, bmi3d_hdf_filename, 'sync_clock') # there isn't any clock metadata
if len(clock) > 0 and isinstance(clock, np.ndarray):
bmi3d_data['bmi3d_clock'] = clock
return bmi3d_data, metadata
def _parse_bmi3d_v1(data_dir, files):
'''
Parser for BMI3D data which incorporates external sync data. Only compatible with sync versions > 0
Args:
data_dir (str): where to look for the data
files (dict): dictionary of files for this experiment
Returns:
tuple: tuple containing:
| **data_dict (dict):** bmi3d data
| **metadata_dict (dict):** bmi3d metadata
'''
# Start by loading bmi3d data using the v0 parser
data_dict, metadata_dict = _parse_bmi3d_v0(data_dir, files)
# Parse digital data
digital_data = None
if 'digital' in files:
digital_data = aodata.load_hdf_data('', files['digital'], 'digital_data')
digital_metadata = aodata.load_hdf_group('', files['digital'], 'digital_metadata')
elif 'ecube' in files:
digital_data, digital_metadata = aodata.load_ecube_digital(data_dir, files['ecube'])
elif 'emg' in files:
digital_data, digital_metadata = aodata.load_emg_digital(data_dir, files['emg'])
if digital_data is not None: # sync_events and sync_clock
digital_samplerate = digital_metadata['samplerate']
# Sync clock
clock_sync_bit_mask = None
if metadata_dict['sync_protocol_version'] < 3:
clock_sync_bit_mask = 0x1000000 # wrong in 1 and 2
elif 'screen_sync_dch' in metadata_dict:
clock_sync_bit_mask = utils.convert_channels_to_mask(metadata_dict['screen_sync_dch'])
if clock_sync_bit_mask is not None:
clock_sync_data = utils.extract_bits(digital_data, clock_sync_bit_mask)
clock_sync_timestamps, _ = utils.detect_edges(clock_sync_data, digital_samplerate, rising=True, falling=False)
sync_clock = np.zeros((len(clock_sync_timestamps),), dtype=[('timestamp', 'f8')])
sync_clock['timestamp'] = clock_sync_timestamps
if len(sync_clock) > 0:
data_dict['sync_clock'] = sync_clock
# Mask and detect BMI3D computer events from ecube
if 'event_sync_dch' in metadata_dict and metadata_dict['event_sync_dch'] is not None:
event_bit_mask = utils.convert_channels_to_mask(metadata_dict['event_sync_dch']) # 0xff0000
ecube_sync_data = utils.extract_bits(digital_data, event_bit_mask)
ecube_sync_timestamps, ecube_sync_events = utils.detect_edges(ecube_sync_data, digital_samplerate,
rising=True, falling=False)
sync_event_names, sync_event_data = decode_events(metadata_dict['event_sync_dict'], ecube_sync_events)
sync_events = np.zeros((len(ecube_sync_timestamps),), dtype=[('time', 'u8'), ('timestamp', 'f8'), ('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
search_radius = 1.5/metadata_dict['fps']
sync_timestamps, ecube_sync_cycles = base.find_measured_event_times(ecube_sync_timestamps, clock_sync_timestamps, search_radius, return_idx=True)
ecube_sync_cycles[0] = 0 # first event is always TIME_ZERO
data_dict['sync_events_timestamp_error'] = ecube_sync_timestamps - sync_timestamps
sync_events['time'] = ecube_sync_cycles
sync_events['timestamp'] = ecube_sync_timestamps
sync_events['code'] = ecube_sync_events
sync_events['event'] = sync_event_names
sync_events['data'] = sync_event_data
# Screen measurements digitized online
if 'screen_measure_dch' in metadata_dict:
if len(sync_events) > 0:
data_dict['sync_events'] = sync_events
measure_clock_online = base.get_dch_data(digital_data, digital_samplerate, metadata_dict['screen_measure_dch'])
if len(measure_clock_online) > 0:
data_dict['measure_clock_online'] = measure_clock_online
# Laser trigger
lasers = aodata.bmi3d.load_bmi3d_lasers()
possible_dch = [laser['trigger_dch'] for laser in lasers]
for dch in possible_dch:
if dch in metadata_dict:
trigger_name = dch[:-4]
data_dict[trigger_name] = base.get_dch_data(digital_data, digital_samplerate, metadata_dict[dch])
# Optical switch
if 'qwalor_switch_rdy_dch' in metadata_dict:
switch_rdy_mask = utils.convert_channels_to_mask(metadata_dict['qwalor_switch_rdy_dch'])
ecube_switch_moving = utils.extract_bits(digital_data, switch_rdy_mask)
switch_bit_mask = utils.convert_channels_to_mask(metadata_dict['qwalor_switch_data_dch']) # 0xff0000
ecube_switch_data = utils.extract_bits(digital_data, switch_bit_mask) + 1 # change to 1-index
masked_ecube_switch_data = ecube_switch_data.copy()
masked_ecube_switch_data[ecube_switch_moving == 1] = 0 # mask data when the switch isn't ready
ecube_switch_timestamps, ecube_switch_channel = utils.detect_edges(masked_ecube_switch_data, digital_samplerate,
rising=True, falling=True, check_alternating=False)
optical_switch = np.zeros((len(ecube_switch_timestamps),), dtype=[('timestamp', 'f8'), ('channel', 'u1')])
optical_switch['timestamp'] = ecube_switch_timestamps
optical_switch['channel'] = ecube_switch_channel # 1-indexed; positive is rising edge, zero is falling edge
data_dict['optical_switch'] = optical_switch
metadata_dict['digital_samplerate'] = digital_samplerate
#Parse analog data
analog_data = None
if 'analog' in files:
analog_data = aodata.load_hdf_data('', files['analog'], 'analog_data')
analog_metadata = aodata.load_hdf_group('', files['analog'], 'analog_metadata')
elif 'ecube' in files:
analog_data, analog_metadata = aodata.load_ecube_analog(data_dir, files['ecube'])
if analog_data is not None:
analog_samplerate = analog_metadata['samplerate']
# Mask and detect screen sensor events (A5 and D5)
if 'screen_measure_ach' in metadata_dict:
clock_measure_analog = analog_data[:, metadata_dict['screen_measure_ach']] # 5
clock_measure_digitized = utils.convert_analog_to_digital(clock_measure_analog, thresh=0.5)
measure_clock_offline = base.get_dch_data(clock_measure_digitized, analog_samplerate, 0)
if len(measure_clock_offline) > 0:
data_dict['measure_clock_offline'] = measure_clock_offline
# And reward system (A0)
if 'reward_measure_ach' in metadata_dict:
reward_system_analog = analog_data[:, metadata_dict['reward_measure_ach']] # 0
reward_system_digitized = utils.convert_analog_to_digital(reward_system_analog)
reward_system_timestamps, reward_system_values = utils.detect_edges(reward_system_digitized, analog_samplerate, rising=True, falling=True)
reward_system = np.zeros((len(reward_system_timestamps),), dtype=[('timestamp', 'f8'), ('state', '?')])
reward_system['timestamp'] = reward_system_timestamps
reward_system['state'] = reward_system_values
if len(reward_system) > 0:
data_dict['reward_system'] = reward_system
# Analog cursor out (A3, A4) since version 11
if 'cursor_x_ach' in metadata_dict and 'cursor_z_ach' in metadata_dict:
cursor_analog = analog_data[:, [metadata_dict['cursor_x_ach'], metadata_dict['cursor_z_ach']]]
cursor_analog, _ = precondition.filter_kinematics(cursor_analog, samplerate=analog_samplerate)
cursor_analog_samplerate = 1000
cursor_analog = precondition.downsample(cursor_analog, analog_samplerate, cursor_analog_samplerate)
max_voltage = 3.34 # using teensy 3.6
cursor_analog_cm = ((cursor_analog * analog_metadata['voltsperbit']) - max_voltage/2) / metadata_dict['cursor_out_gain']
data_dict.update({
'cursor_analog_volts': cursor_analog,
'cursor_analog_cm': cursor_analog_cm,
})
metadata_dict['cursor_analog_samplerate'] = cursor_analog_samplerate
metadata_dict.update({
'analog_samplerate': analog_samplerate,
'analog_voltsperbit': analog_metadata['voltsperbit']
})
# Laser sensors
lasers = aodata.bmi3d.load_bmi3d_lasers()
possible_ach = [laser['sensor_ach'] for laser in lasers]
sensor_names = [laser['sensor'] for laser in lasers]
for ach, name in zip(possible_ach, sensor_names):
if ach in metadata_dict:
laser_sensor_data = analog_data[:, metadata_dict[ach]]
data_dict[name] = laser_sensor_data
return data_dict, metadata_dict
def _prepare_bmi3d_v0(data, metadata):
'''
Organizes the bmi3d data and metadata for experiments with no external data sources.
Args:
data (dict): bmi3d data
metadata (dict): bmi3d metadata
Returns:
tuple: tuple containing:
| **data (dict):** prepared bmi3d data
| **metadata (dict):** prepared bmi3d metadata
'''
preproc_errors = []
if metadata['sync_protocol_version'] > 0:
warnings.warn("This version of the parser should only be used without any external data sources (e.g. tablet)")
preproc_errors.append(f"Incompatible parser for sync protocol version {metadata['sync_protocol_version']}")
if 'bmi3d_clock' in data:
data['clock'] = data['bmi3d_clock']
names = list(data['clock'].dtype.names)
names[np.where([n == 'timestamp' for n in names])[0][0]] = 'timestamp_bmi3d'
data['clock'].dtype.names = names
elif 'bmi3d_task' in data:
warnings.warn("No clock data found, timing will be inaccurate. Enable HDF sync to fix this")
preproc_errors.append("No clock data found! Used task data instead. Enable HDF sync to fix this.")
# Estimate timestamps
bmi3d_cycles = np.arange(len(data['bmi3d_task']))
bmi3d_timestamps = bmi3d_cycles/metadata['fps']
bmi3d_clock = np.zeros((len(data['bmi3d_task']),), dtype=[('time', 'u8'), ('timestamp_bmi3d', 'f8')])
bmi3d_clock['time'] = bmi3d_cycles
bmi3d_clock['timestamp_bmi3d'] = bmi3d_timestamps
data['clock'] = bmi3d_clock
else:
warnings.warn("No clock or task data found! Cannot accurately prepare bmi3d data")
preproc_errors.append("No task data found! Cannot accurately prepare bmi3d data")
data['task'] = np.zeros((0,), dtype=[('time', 'f8'), ('cursor', 'f8', (3,))])
data['clock'] = np.zeros((0,), dtype=[('time', 'f8'), ('timestamp_bmi3d', 'f8')])
return data, metadata
if 'bmi3d_events' in data:
corrected_events = np.zeros((len(data['bmi3d_events']),),
dtype=[('time', 'u8'), ('timestamp', 'f8'), ('timestamp_bmi3d', 'f8'),
('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
corrected_events['time'] = data['bmi3d_events']['time']
corrected_events['timestamp_bmi3d'] = np.asarray([data['clock']['timestamp_bmi3d'][cycle] for cycle in data['bmi3d_events']['time']])
corrected_events['timestamp'] = corrected_events['timestamp_bmi3d']
corrected_events['code'] = data['bmi3d_events']['code']
corrected_events['event'] = data['bmi3d_events']['event']
if 'data' in data['bmi3d_events'].dtype.names:
corrected_events['data'] = data['bmi3d_events']['data']
data['events'] = corrected_events
else:
warnings.warn("No event data found! Cannot accurately prepare bmi3d data")
preproc_errors.append("No event data found! Cannot accurately prepare bmi3d data")
data['events'] = np.zeros((0,), dtype=[('time', 'u8'), ('timestamp', 'f8'), ('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
# Add task data
if 'bmi3d_task' in data:
task = data['bmi3d_task']
# special handling for an old version of the tracking task
if 'generator' in metadata and metadata['generator']=='tracking_target_chain':
if 'current_target' in task.dtype.names:
# task data from bmi3d has bugs in saved reference and disturbance
task = _correct_tracking_task_data(data, metadata, contains_hand=False)
# special handling for an early version of tablet touch app
if 'tablet_touch' in metadata['features']:
task = _correct_touch_app_data(task)
data['task'] = task
else:
warnings.warn("No task data found! Cannot accurately prepare bmi3d data")
preproc_errors.append("No task data found! Cannot accurately prepare bmi3d data")
data['task'] = np.zeros((0,), dtype=[('time', 'f8'), ('cursor', 'f8', (3,))])
# Compare task and clock data
if abs(len(data['task']) - len(data['clock'])) == 1:
data['clock'] = data['clock'][:len(data['task'])]
data['task'] = data['task'][:len(data['clock'])]
elif len(data['task']) != len(data['clock']):
warnings.warn(f"Number of task cycles ({len(data['task'])}) doesn't match number of clock cycles ({len(data['clock'])}).")
preproc_errors.append(f"Number of task cycles ({len(data['task'])}) doesn't match number of clock cycles ({len(data['clock'])}).")
if isinstance(data['task'], np.ndarray) and 'manual_input' in data['task'].dtype.names:
data['clean_hand_position'] = data['task']['manual_input']
metadata['preproc_errors'] = preproc_errors
return data, metadata
def _prepare_bmi3d_v1(data, metadata):
'''
Organizes the bmi3d data and metadata and computes some automatic conversions. Corrects for
unreliable sync clock signal, finds measured timestamps, and pads the clock for versions
with a sync period at the beginning of the experiment. Should be applied to data with
sync protocol version > 0. Prioritizes bmi3d events over sync events.
Args:
data (dict): bmi3d data
metadata (dict): bmi3d metadata
Returns:
tuple: tuple containing:
| **data (dict):** prepared bmi3d data
| **metadata (dict):** prepared bmi3d metadata
'''
preproc_errors = []
if not metadata['sync_protocol_version'] > 0:
warnings.warn("This version of the parser is only compatible with sync protocol version > 0")
preproc_errors.append(f"Incompatible parser for sync protocol version {metadata['sync_protocol_version']}")
# Estimate display latency
if 'sync_clock' in data and 'measure_clock_offline' in data:
# Estimate the latency based on the "sync" state at the beginning of the experiment
sync_impulse = data['sync_clock']['timestamp'][1:3]
measure_impulse = base.get_measured_clock_timestamps(sync_impulse, data['measure_clock_offline']['timestamp'],
latency_estimate=0.01, search_radius=0.1)
if np.count_nonzero(np.isnan(measure_impulse)) > 0:
warnings.warn("Warning: sync failed. Using latency estimate 0.01")
measure_latency_estimate = 0.01
preproc_errors.append(f"Screen sync failed. Guessing {measure_latency_estimate} s latency")
else:
measure_latency_estimate = np.mean(measure_impulse - sync_impulse)
print("Sync latency estimate: {:.4f} s".format(measure_latency_estimate))
else:
measure_latency_estimate = 0.01 # Guess 10 ms
metadata['measure_latency_estimate'] = measure_latency_estimate
# Use the sync clock if it exists
corrected_clock = {}
if 'sync_clock' in data:
corrected_clock['time'] = np.arange(len(data['sync_clock']))
corrected_clock['timestamp_sync'] = data['sync_clock']['timestamp']
approx_clock = data['sync_clock']['timestamp'].copy()
# Make sure the sync clock is the same length as the internal clock
sync_search_radius = 1.5/metadata['fps']
valid_clock_cycles = len(corrected_clock)
if 'bmi3d_clock' in data and 'sync_clock' in data:
sync_clock = data['sync_clock']
internal_clock = data['bmi3d_clock']
corrected_clock['time'] = internal_clock['time'].copy()
approx_clock = internal_clock['timestamp'].copy() - internal_clock['timestamp'][0] + sync_clock['timestamp'][0]
corrected_clock['timestamp_bmi3d'] = approx_clock
# Fill in the holes if necessary
if len(sync_clock) < len(internal_clock):
warnings.warn(f"Warning: length of clock timestamps on eCube ({len(sync_clock)}) doesn't match bmi3d record ({len(internal_clock)})")
valid_clock_cycles = len(sync_clock)
# Adjust the internal clock so that it starts at the same time as the sync clock
approx_clock = approx_clock + sync_clock['timestamp'][0] - approx_clock[0]
# Find sync clock pulses that match up to the expected internal clock timestamps within 1 radius
timestamp_sync = base.get_measured_clock_timestamps(
approx_clock, sync_clock['timestamp'], 0, sync_search_radius) # assume no latency between bmi3d and ecube via nidaq
nanmask = np.isnan(timestamp_sync)
timestamp_sync[nanmask] = approx_clock[nanmask] # if nothing, then use the approximated value
# Use the bmi3d clock cycles and the new estimated sync timestamps
corrected_clock['timestamp_sync'] = timestamp_sync
preproc_errors.append(f"Only {len(sync_clock)} out of {len(internal_clock)} clock cycles were recorded")
if len(sync_clock) > len(internal_clock):
warnings.warn(f"Extra clock cycles were recorded ({len(sync_clock)} out of {len(internal_clock)})")
preproc_errors.append(f"Extra clock cycles were recorded ({len(sync_clock)} out of {len(internal_clock)})")
# Otherwise fall back on the bmi3d clock
elif 'bmi3d_clock' in data:
corrected_clock['time'] = data['bmi3d_clock']['time'].copy()
corrected_clock['timestamp_bmi3d'] = data['bmi3d_clock']['timestamp'].copy()
approx_clock = data['bmi3d_clock']['timestamp'].copy()
warnings.warn("No sync clock present even though external data is available!")
preproc_errors.append("No sync clock present even though external data is available!")
elif 'sync_clock' not in data:
warnings.warn("No clock data found! Cannot accurately prepare bmi3d data")
preproc_errors.append("No clock data found! Cannot accurately prepare bmi3d data")
data['clock'] = np.zeros((0,), dtype=[('time', 'f8'), ('timestamp_bmi3d', 'f8')])
return data, metadata
# Estimate screen timing with photodiode measurements
measure_search_radius = 1.5/metadata['fps']
max_consecutive_missing_cycles = metadata['fps'] # maximum 1 second missing
if 'measure_clock_offline' in data:
timestamp_measure_offline = base.get_measured_clock_timestamps(
approx_clock, data['measure_clock_offline']['timestamp'],
measure_latency_estimate, measure_search_radius)
# Check the integrity of the measured timestamps
metadata['latency_measured'] = np.nanmean(timestamp_measure_offline - approx_clock)
metadata['n_missing_markers'] = np.count_nonzero(np.isnan(timestamp_measure_offline[:valid_clock_cycles]))
n_consecutive_missing_cycles = utils.max_repeated_nans(timestamp_measure_offline[:valid_clock_cycles])
if n_consecutive_missing_cycles < max_consecutive_missing_cycles:
metadata['has_measured_timestamps'] = True
corrected_clock['timestamp_measure_offline'] = timestamp_measure_offline
else:
warnings.warn(f"Analog screen sensor missing too many markers ({n_consecutive_missing_cycles}/{max_consecutive_missing_cycles}). Ignoring")
preproc_errors.append(f"Analog screen sensor missing too many markers ({n_consecutive_missing_cycles}/{max_consecutive_missing_cycles})")
# Assemble the corrected clock
corrected_clock = pd.DataFrame.from_dict(corrected_clock).to_records(index=False)
# Trim / pad the clock
if metadata['sync_protocol_version'] >= 3 and metadata['sync_protocol_version'] < 6:
n_cycles = int(data['bmi3d_clock']['time'][-1])
# Due to the "sync" state at the beginning of the experiment, we need
# to add some (meaningless) cycles to the beginning of the clock
state_log = data['bmi3d_state']
n_sync_cycles = state_log['time'][1] # 120, approximately
n_sync_clocks = np.count_nonzero(data['bmi3d_clock']['time'] < n_sync_cycles)
padded_clock = np.zeros((n_cycles,), dtype=corrected_clock.dtype)
padded_clock[n_sync_cycles:] = corrected_clock[n_sync_clocks:]
padded_clock['time'][:n_sync_cycles] = range(n_sync_cycles)
corrected_clock = padded_clock
# By default use the internal events if they exist
if 'bmi3d_events' in data:
corrected_events = np.zeros((len(data['bmi3d_events']),), dtype=
[('time', 'u8'), ('timestamp', 'f8'), ('timestamp_bmi3d', 'f8'),
('timestamp_sync', 'f8'), ('timestamp_measure', 'f8'),
('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
corrected_events['time'] = data['bmi3d_events']['time']
corrected_events['code'] = data['bmi3d_events']['code']
corrected_events['event'] = data['bmi3d_events']['event']
try:
corrected_events['data'] = data['bmi3d_events']['data']
except:
corrected_events['data'] = 0
# Otherwise fall back on the sync events
elif 'sync_events' in data:
warnings.warn("No bmi3d event data found! Attempting to use sync events instead")
preproc_errors.append("No bmi3d sync event data found! Attempted to use sync events instead")
corrected_events = np.zeros((len(data['sync_events']),), dtype=
[('time', 'u8'), ('timestamp', 'f8'), ('timestamp_bmi3d', 'f8'),
('timestamp_sync', 'f8'), ('timestamp_measure', 'f8'),
('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
corrected_events['time'] = data['sync_events']['time']
corrected_events['code'] = data['sync_events']['code']
corrected_events['event'] = data['sync_events']['event']
corrected_events['data'] = data['sync_events']['data']
else:
warnings.warn("No bmi3d or sync events present!")
preproc_errors.append("No bmi3d or sync events present!")
data['events'] = np.zeros((0,), dtype=[('timestamp', 'f8'), ('code', 'u1'), ('event', 'S32'), ('data', 'u4')])
# Add timestamp fields to the events
if 'timestamp_sync' in corrected_clock.dtype.names:
timestamp_sync = np.asarray([corrected_clock['timestamp_sync'][cycle] for cycle in corrected_events['time']])
else:
timestamp_sync = np.nan*np.zeros((len(corrected_events),))
if 'timestamp_bmi3d' in corrected_clock.dtype.names:
timestamp_bmi3d = np.asarray([corrected_clock['timestamp_bmi3d'][cycle] for cycle in corrected_events['time']])
else:
timestamp_bmi3d = np.nan*np.zeros((len(corrected_events),))
if 'timestamp_measure_offline' in corrected_clock.dtype.names:
timestamp_measure = np.asarray([corrected_clock['timestamp_measure_offline'][cycle] for cycle in corrected_events['time']])
else:
timestamp_measure = np.nan*np.zeros((len(corrected_events),))
# And keep a copy of each clock for convenience
corrected_events['timestamp_measure'] = timestamp_measure
corrected_events['timestamp_sync'] = timestamp_sync
corrected_events['timestamp_bmi3d'] = timestamp_bmi3d
corrected_events['timestamp'] = timestamp_sync # the default for simplicity
# Check the integrity of the sync events from all the sources
if 'bmi3d_events' in data and 'sync_events' in data:
if len(data['sync_events']) != len(corrected_events):
warnings.warn(f"Number of sync events ({len(data['sync_events'])}) doesn't match number of bmi3d events ({len(corrected_events)}).")
preproc_errors.append(f"Number of sync events ({len(data['sync_events'])}) doesn't match number of bmi3d events ({len(corrected_events)}).")
# Add task data
if 'bmi3d_task' in data:
task = data['bmi3d_task']
# special handling for an old version of the tracking task
if 'generator' in metadata and metadata['generator']=='tracking_target_chain':
if 'current_target' in task.dtype.names:
# task data from bmi3d has bugs in saved reference and disturbance
task = _correct_tracking_task_data(data, metadata)
elif 'timestamp_sync' in corrected_clock.dtype.names:
warnings.warn("No task data found! Reconstructing from sync data")
preproc_errors.append("No hdf task data found! Attempted to reconstruct from sync data")
# Reconstruct the task data from available data
task = np.zeros((len(corrected_clock),), dtype=[('time', 'u8'), ('cursor', 'f8', (3,)),])
task['time'] = np.arange(len(corrected_clock))
task['cursor'] = 0
if 'cursor_analog_cm' in data:
time = np.arange(len(data['cursor_analog_cm']))/metadata['cursor_analog_samplerate']
samples = np.searchsorted(time, corrected_clock['timestamp_sync'])
task['cursor'][:,[0,2]] = data['cursor_analog_cm'][samples]
else:
warnings.warn("No task data found!")
preproc_errors.append("No hdf task data found!")
task = np.zeros((0,), dtype=[('time', 'f8'), ('cursor', 'f8', (3,))])
# Compare the task and clock data
if abs(len(task) - len(corrected_clock)) == 1:
corrected_clock = corrected_clock[:len(task)]
task = task[:len(corrected_clock)]
elif len(task) != len(corrected_clock):
warnings.warn(f"Number of task cycles ({len(task)}) doesn't match number of clock cycles ({len(corrected_clock)}).")
preproc_errors.append(f"Number of task cycles ({len(task)}) doesn't match number of clock cycles ({len(corrected_clock)}).")
data.update({
'task': task,
'clock': corrected_clock,
'events': corrected_events,
})
# In some versions of BMI3D, hand position contained erroneous data
# caused by `np.empty()` instead of `np.nan`. The 'clean_hand_position'
# replaces these bad data with `np.nan`.
if metadata['sync_protocol_version'] < 14 and isinstance(task, np.ndarray) and 'manual_input' in task.dtype.names:
clean_hand_position = _correct_hand_traj(task['manual_input'], task['cursor'])
if np.count_nonzero(~np.isnan(clean_hand_position)) > 2*clean_hand_position.ndim:
data['clean_hand_position'] = clean_hand_position
elif isinstance(task, np.ndarray) and 'manual_input' in task.dtype.names:
data['clean_hand_position'] = task['manual_input']
# Interpolate clean hand kinematics
if ('timestamp_sync' in corrected_clock.dtype.names and
'clean_hand_position' in data and
len(data['clean_hand_position']) > 0 and
np.count_nonzero(~np.isnan(data['clean_hand_position'])) > 0):
metadata['hand_interp_samplerate'] = 1000
data['hand_interp'] = aodata.get_interp_task_data(data, metadata, datatype='hand', samplerate=metadata['hand_interp_samplerate'])
# And interpolated cursor kinematics
if ('timestamp_sync' in corrected_clock.dtype.names and
len(corrected_clock) > 0 and
'cursor' in task.dtype.names and
len(task['cursor']) > 0 and
np.count_nonzero(~np.isnan(task['cursor'])) > 0):
metadata['cursor_interp_samplerate'] = 1000
data['cursor_interp'] = aodata.get_interp_task_data(data, metadata, datatype='cursor', samplerate=metadata['cursor_interp_samplerate'])
metadata['preproc_errors'] = preproc_errors
return data, metadata
[docs]def get_peak_power_mW(exp_metadata):
"""
Estimate the peak power from the date
Args:
exp_metadata (dict): bmi3d metadata
Returns:
float: peak power in mW
"""
date = datetime.fromisoformat(exp_metadata['date']).date()
if 'qwalor_peak_watts' in exp_metadata:
peak_power_mW = exp_metadata['qwalor_peak_watts']
elif date < datetime(2022,5,31).date():
if 'qwalor_channel' in exp_metadata and exp_metadata['qwalor_channel'] == 4:
peak_power_mW = 1.5
else:
peak_power_mW = 20
elif date < datetime(2022,9,30).date():
peak_power_mW = 1.5
elif date < datetime(2023,1,23).date():
peak_power_mW = 20
else:
peak_power_mW = 25
return peak_power_mW
def _get_laser_trial_times_old_data(exp_data, exp_metadata, laser_sensor='qwalor_sensor',
calibration_file='qwalor_447nm_ch2.yaml', debug=False, **kwargs):
'''
Get the laser trial times, trial widths, and trial powers from the given experiment. Returned
values are computed from the laser sensor in combination with the expected laser events from
BMI3D's hdf records. Not recommended for use with experiments with sync_protocol_version > 12.
Args:
exp_data (dict): bmi3d data
exp_metadata (dict): bmi3d metadata
laser_sensor (str, optional): Specifies the name of the analog laser sensor
calibration_file (str, optional): Specifies the name of the calibration file for the laser sensor
debug (bool, optional): print a plot of the laser sensor aligned to the computed times
kwargs (dict): to be passed to `:func:~aopy.preproc.laser.find_stim_times`
Returns:
tuple: tuple containing:
| **corrected_times (nevent):** corrected laser timings (seconds)
| **corrected_widths (nevent):** corrected laser widths (seconds)
| **corrected_powers (nevent):** corrected laser powers (fraction of maximum)
| **times_not_found (nevent):** boolean array of times without onset and offset sensor measurements
| **widths_above_thr (nevent):** boolean array of widths above the given threshold from the expected width
| **powers_above_thr (nevent):** boolean array of powers above the given threshold from the expected power
'''
if laser_sensor not in exp_data:
raise ValueError(f"Could not find laser sensor data ({laser_sensor}). Try preprocessing the data first")
# Some older experiments didn't have laser trigger data. Instead, we estimate
# the laser event timing from BMI3D's sync events, then use the laser sensor to find
# a more accurate time. However, in some cases the laser sensor data is too noisy so
# we fall back on these inaccurate times.
events = exp_data['sync_events']['event']
event_times = exp_data['sync_events']['timestamp']
times = event_times[events == b'TRIAL_START']
min_isi = np.min(np.diff(times))
if 'search_radius' not in kwargs:
kwargs['search_radius'] = min_isi/2 # look for sensor measurements up to half the minimum ISI away
# Get width and power from the 'trials' data
if 'bmi3d_trials' in exp_data:
trials = exp_data['bmi3d_trials']
gains = trials['power'][:len(times)]
edges = trials['edges'][:len(times)]
widths = np.array([t[1] - t[0] for t in edges])
thr_width=0.001
thr_power=0.05
else:
# In the very old experiments, the laser width and power were not recorded. Since we
# can't use the exp data as ground truth, so just trust the analog sensor data. Timing may be off.
widths = np.zeros((len(times),))
gains = np.ones((len(times),))
thr_width = 999
thr_power = 1
# Correct the event timings using the sensor data
sensor_data = exp_data[laser_sensor]
sensor_voltsperbit = exp_metadata['analog_voltsperbit']
samplerate = exp_metadata['analog_samplerate']
peak_power_mW = get_peak_power_mW(exp_metadata)
(corrected_times, corrected_widths, corrected_powers, times_not_found, widths_above_thr,
powers_above_thr) = laser.find_stim_times(times, widths, gains, sensor_data,
samplerate, sensor_voltsperbit, peak_power_mW, thr_width=thr_width,
thr_power=thr_power, calibration_file=calibration_file, debug=debug, **kwargs)
if np.sum(times_not_found) > 0:
warnings.warn(f"{np.sum(times_not_found)} laser trials missing onset and/or offset sensor measurements")
if np.sum(widths_above_thr) > 0:
warnings.warn(f"{np.sum(widths_above_thr)} laser trials have widths above the given threshold")
if np.sum(powers_above_thr) > 0:
warnings.warn(f"{np.sum(powers_above_thr)} laser trials have powers above the given threshold")
return corrected_times, corrected_widths, gains, corrected_powers
def _get_laser_trial_times(exp_data, exp_metadata, laser_trigger='qwalor_trigger',
laser_sensor='qwalor_sensor', debug=False, **kwargs):
'''
Get the laser trial times, trial widths, and trial powers from the given experiment. Returned
values are computed from the laser sensor in combination with the expected laser events from
BMI3D's hdf records. Not recommended for use with experiments with sync_protocol_version > 12.
Args:
exp_data (dict): bmi3d data
exp_metadata (dict): bmi3d metadata
laser_trigger (str, optional): Specifies the name of the digital laser trigger
laser_sensor (str, optional): Specifies the name of the analog laser sensor
debug (bool, optional): print a plot of the laser sensor aligned to the computed times
kwargs (dict): to be passed to `:func:~aopy.preproc.laser.find_stim_times`
Returns:
tuple: tuple containing:
| **times (nevent):** laser timings (seconds)
| **widths (nevent):** laser widths (seconds)
| **gains (nevent):** laser gains (fraction)
| **powers (nevent):** calibrated laser powers (mW)
'''
# Use the digital trigger as the ground truth of timing
timestamps = exp_data[laser_trigger]['timestamp']
values = exp_data[laser_trigger]['value']
times = timestamps[values == 1]
widths = timestamps[values == 0] - timestamps[values == 1]
# Figure out the intended gain of each pulse
if 'bmi3d_trials' in exp_data and 'power' in exp_data['bmi3d_trials'].dtype.names:
trials = exp_data['bmi3d_trials']
gains = trials['power'][:len(times)]
elif 'laser_power' in exp_metadata:
try:
gains = np.ones((len(times),)) * exp_metadata['laser_power']
except:
gains = np.ones((len(times),))
else:
# In the very old experiments, the laser width and power were not recorded.
gains = np.ones((len(times),))
sensor_data = exp_data[laser_sensor]
sensor_voltsperbit = exp_metadata['analog_voltsperbit']
samplerate = exp_metadata['analog_samplerate']
laser_on_times = np.vstack([timestamps[values == 1], timestamps[values == 0]]).T
laser_on_samples = (laser_on_times * samplerate).astype(int)
laser_sensor_values = np.array([np.median(sensor_data[laser_on_samples[t,0]:laser_on_samples[t,1]])
for t in range(len(laser_on_samples))], dtype='float')
# Estimate the peak power from the date
peak_power_mW = get_peak_power_mW(exp_metadata)
powers = laser.calibrate_sensor(laser_sensor_values * sensor_voltsperbit, peak_power_mW, **kwargs)
if debug:
print(f"eCube recorded {len(times)} stims")
plt.figure()
visualization.plot_laser_sensor_alignment(sensor_data*sensor_voltsperbit, samplerate, times)
return times, widths, gains, powers
[docs]def get_laser_trial_times(preproc_dir, subject, te_id, date, laser_trigger='qwalor_trigger',
laser_sensor='qwalor_sensor', debug=False, **kwargs):
'''
Get the laser trial times, trial widths, and trial powers from the given experiment. Returned
values are computed from the laser sensor in combination with the expected laser events from
BMI3D's hdf records.
Args:
preproc_dir (str): base directory where the files live
subject (str): Subject name
te_id (int): Block number of Task entry object
date (str): Date of recording
laser_trigger (str, optional): Specifies the name of the digital laser trigger
laser_sensor (str, optional): Specifies the name of the analog laser sensor
kwargs (dict): to be passed to `:func:~aopy.preproc.laser.find_stim_times`
Returns:
tuple: tuple containing:
| **times (nevent):** laser timings (seconds)
| **widths (nevent):** laser widths (seconds)
| **gains (nevent):** laser gains (fraction)
| **powers (nevent):** calibrated laser powers (mW)
'''
exp_data, exp_metadata = aodata.load_preproc_exp_data(preproc_dir, subject, te_id, date)
# Load the sensor data if it's not already in the bmi3d data
if laser_sensor not in exp_data:
files, data_dir = aodata.get_source_files(preproc_dir, subject, te_id, date)
hdf_filepath = os.path.join(data_dir, files['hdf'])
if not os.path.exists(hdf_filepath):
raise FileNotFoundError(f"Could not find raw files for te {te_id} ({hdf_filepath})")
exp_data, exp_metadata = parse_bmi3d(data_dir, files)
# Experiments need to be handled differently depending on whether the trigger was recorded
if laser_trigger in exp_data:
# Return ground truth timestamps of when the laser should have been turned on
return _get_laser_trial_times(exp_data, exp_metadata, laser_trigger=laser_trigger,
laser_sensor=laser_sensor, debug=debug, **kwargs)
else:
# Use the bmi3d events as an estimate of timing, then locate the nearby sensor measurements
return _get_laser_trial_times_old_data(exp_data, exp_metadata, laser_sensor=laser_sensor,
debug=debug, **kwargs)
[docs]def get_switched_stimulation_sites(preproc_dir, subject, te_id, date, trigger_timestamps,
return_switch_ch=False, debug=False):
'''
Get the stimulation sites at the given timestamps from an experiment where an optical switch was used.
Args:
preproc_dir (str): base directory where the files live
subject (str): Subject name
te_id (int): Block number of Task entry object
date (str): Date of recording
trigger_timestamps (nt,): timestamps of interest
return_switch_ch (bool, optional): also return the switch channel at the computed times
debug (bool, optional): print a plot of the optical switch channel at the computed times
Returns:
(nt,): stimulation sites at the given timestamps, or np.nan if no site was selected.
'''
exp_data, exp_metadata = aodata.load_preproc_exp_data(preproc_dir, subject, te_id, date)
# Check that the optical switch was present
if 'qwalor_switch_rdy_dch' not in exp_metadata:
raise ValueError("No optical switch data found in the experiment")
# Load the optical switch data if it's not already in the bmi3d data
if 'optical_switch' not in exp_data:
files, data_dir = aodata.get_source_files(preproc_dir, subject, te_id, date)
hdf_filepath = os.path.join(data_dir, files['hdf'])
if not os.path.exists(hdf_filepath):
raise FileNotFoundError(f"Could not find raw files for te {te_id} ({hdf_filepath})")
exp_data, exp_metadata = parse_bmi3d(data_dir, files)
optical_switch = exp_data['optical_switch']
optical_switch_timestamps = optical_switch['timestamp']
optical_switch_channels = optical_switch['channel']
switch_channels = np.zeros((len(trigger_timestamps),), dtype='float')
for i, t in enumerate(trigger_timestamps):
# Find the most recent switch time before the trigger time
idx = np.searchsorted(optical_switch_timestamps, t, side='right') - 1
if idx < 0: # before the first switch
switch_channels[idx] = optical_switch_channels[0]
elif len(optical_switch_timestamps) > idx + 1 and optical_switch_timestamps[idx+1] < t:
switch_channels[idx] = 0
else:
switch_channels[i] = optical_switch_channels[idx]
switch_channels[switch_channels <= 0] = np.nan # no site selected
switch_channels -= 1 # 1-indexed to 0-indexed
stimulation_site = [exp_metadata['stimulation_site'][int(ch)] if not np.isnan(ch) else None for ch in switch_channels]
if debug:
plt.figure()
plt.step(optical_switch_timestamps, optical_switch_channels, where='post')
plt.plot(trigger_timestamps, switch_channels + 1, 'ro')
for i, txt in enumerate(stimulation_site):
plt.text(trigger_timestamps[i], switch_channels[i] + 1, txt, fontsize=6, ha='center', va='center', color='w')
plt.xlabel('time (s)')
plt.ylabel('switch channel (1-indexed)')
if return_switch_ch:
return stimulation_site, switch_channels
else:
return stimulation_site
[docs]def get_target_events(exp_data, exp_metadata):
'''
For target acquisition tasks, get an (n_event, n_target) array encoding the position
of each target whenever an event is fired by BMI3D. The resulting sequence is used
to generate a sampled timeseries in :func:`~aopy.data.bmi3d.get_kinematic_segments`.
When targets are turned off, their position is replaced by np.nan.
Args:
exp_data (dict): A dictionary containing the experiment data.
exp_metadata (dict): A dictionary containing the experiment metadata.
Returns:
(n_event, n_target, 3) array: position of each target at each event time.
'''
events = exp_data['events']['event']
event_data = exp_data['events']['data']
trials = exp_data['bmi3d_trials']
target_idx, location_idx = np.unique(trials['index'], axis=0, return_index=True)
locations = [np.round(t[[0,2,1]], 4) for t in trials['target'][location_idx]]
# Generate events for each unique target
target_events = []
for idx in range(len(locations)):
target_location = locations[idx]
# Create a nan mask encoding when each target is turned on
target_on = np.zeros((len(events),))
on = np.nan
for event_idx, (event, data) in enumerate(zip(events, event_data)):
if (event == b'TARGET_ON') and (data == target_idx[idx]):
on = 1
elif event == b'TRIAL_END' or ((event == b'TARGET_OFF') and (data == target_idx[idx])):
on = np.nan
target_on[event_idx] = on
# Set the non-nan values to the target location
event_target = target_location[None,:] * target_on[:,None]
target_events.append(event_target)
return np.array(target_events).transpose(1,0,2)
[docs]def get_ref_dis_frequencies(data, metadata):
'''
For continuous tracking tasks, get the set of frequencies (in Hz) used to
generate the reference and disturbance trajectories that were preesented
on each trial of the experiment.
Note:
This function should be used with caution on task entries that have mismatched sync and bmi3d events!
Prior to 11-16-2022, bmi3d did not allow the number of experimental frequencies to be set by the experimenter,
and this parameter defaulted to 8.
Prior to 2-23-2023, bmi3d did not save the generator index in the task data, and this had to be calculated
by the number of times bmi3d entered the 'wait' state.
Args:
data (dict): A dictionary containing the experiment data.
metadata (dict): A dictionary containing the experiment metadata.
Returns:
tuple: Tuple containing:
| **freq_r (list of arrays):** (ntrial) list of (nfreq,) frequencies used to generate reference trajectory
| **freq_d (list of arrays):** (ntrial) list of (nfreq,) frequencies used to generate disturbance trajectory
Examples:
.. code-block:: python
subject = 'test'
te_id = '8461'
date = '2023-02-25'
data, metadata = load_preproc_exp_data(data_dir, subject, te_id, date)
freq_r, freq_d = get_ref_dis_frequencies(data, metadata)
plt.figure()
plt.plot(freq_r, 'darkorange')
plt.plot(freq_d, 'tab:red', linestyle='--')
plt.xlabel('Trial #'); plt.ylabel('Frequency (Hz)')
.. image:: _images/get_ref_dis_freqs_test.png
.. code-block:: python
subject = 'churro'
te_id = '375'
date = '2023-10-02'
data, metadata = load_preproc_exp_data(data_dir, subject, te_id, date)
freq_r, freq_d = get_ref_dis_frequencies(data, metadata)
plt.figure()
plt.plot(freq_r, 'darkorange')
plt.plot(freq_d, 'tab:red', linestyle='--')
plt.xlabel('Trial #'); plt.ylabel('Frequency (Hz)')
.. image:: _images/get_ref_dis_freqs_churro.png
'''
# grab params relevant for generator
params = json.loads(metadata['sequence_params'])
if 'num_primes' not in params.keys():
params['num_primes'] = 8
primes = np.asarray(list(sympy.primerange(0, sympy.prime(params['num_primes'])+1)))
even_idx = np.arange(len(primes))[0::2]
odd_idx = np.arange(len(primes))[1::2]
base_period = 20
# recreate random trial order of reference & disturbance frequencies
np.random.seed(params['seed'])
o = np.random.rand(params['ntrials'],primes.size) # phase offset - need to generate this like in bmi3d to reproduce correct random order
order = np.random.choice([0,1])
if order == 0:
trial_r_idx = np.array([even_idx, odd_idx]*params['ntrials'], dtype='object')
trial_d_idx = np.array([odd_idx, even_idx]*params['ntrials'], dtype='object')
elif order == 1:
trial_r_idx = np.array([odd_idx, even_idx]*params['ntrials'], dtype='object')
trial_d_idx = np.array([even_idx, odd_idx]*params['ntrials'], dtype='object')
# get trial segments
events = data['bmi3d_events']['code']
cycles = data['bmi3d_events']['time'] # bmi3d cycle number
start_codes = [metadata['event_sync_dict']['TARGET_ON']]
if 'PAUSE_START' in metadata['event_sync_dict']:
end_codes = [metadata['event_sync_dict']['TRIAL_END'], metadata['event_sync_dict']['PAUSE_START']]
else:
end_codes = [metadata['event_sync_dict']['TRIAL_END'], metadata['event_sync_dict']['PAUSE']]
_, segment_cycles = base.get_trial_segments(events, cycles, start_codes, end_codes)
# get trajectory generator index used for each trial
if 'gen_idx' in data['task'].dtype.names:
# get generator index from task data (saved on every bmi3d cycle)
generator_segments = np.array([data['task']['gen_idx'][cycle[0]:cycle[-1]] for cycle in segment_cycles], dtype='object')
assert np.all([gen[0]==gen[-1] for gen in generator_segments]), 'Generator index is not consistent throughout trial segment!'
generator_idx = [int(gen[0]) for gen in generator_segments]
assert (np.diff(generator_idx) >= 0).all(), 'Generator index should stay the same or increase over trials, never decrease!'
else:
# get generator index from number of previous wait states (wait state parses next trial)
states = data['bmi3d_state']['msg']
state_cycles = data['bmi3d_state']['time'] # bmi3d cycle number
generator_idx = [sum(states[state_cycles <= cycle[0]] == b'wait')-1 for cycle in segment_cycles]
assert (np.diff(generator_idx) >= 0).all(), 'Generator index should stay the same or increase over trials, never decrease!'
# use generator index to get reference & disturbance frequencies for each trial
freq_r = [primes[np.array(idx, dtype=int)]/base_period for idx in trial_r_idx[generator_idx]]
freq_d = [primes[np.array(idx, dtype=int)]/base_period for idx in trial_d_idx[generator_idx]]
return freq_r, freq_d