Source code for ephyspy.features.spike_features

#!/usr/bin/env python3
# Copyright 2023 Jonas Beck

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Dict, Optional

import numpy as np
from matplotlib.axes import Axes

from ephyspy.features.base import SpikeFeature
from ephyspy.features.utils import fetch_available_fts
from ephyspy.utils import (
    fwhm,
    has_spike_feature,
    is_spike_feature,
    scatter_spike_ft,
    unpack,
)


def available_spike_features(**kwargs) -> Dict[str, SpikeFeature]:
    """Return a dictionary of all implemented spike features.

    Collects everything returned by `fetch_available_fts` that passes
    `is_spike_feature` and keys it by the lower-cased class name with the
    "spike_" substring removed.

    Args:
        **kwargs: Optional keyword arguments. When given, each dictionary value
            is a factory that forwards its own arguments plus these kwargs to
            the corresponding feature class (e.g. `compute_at_init=True`).
            When omitted, the feature classes themselves are returned.

    Returns:
        dict[str, SpikeFeature]: Dictionary of all available spike features.
    """
    features = {
        ft.__name__.lower().replace("spike_", ""): ft
        for ft in fetch_available_fts()
        if is_spike_feature(ft)
    }
    if not kwargs:
        return features

    def _with_defaults(ft):
        # Bind `ft` in its own scope. A lambda written directly inside the
        # comprehension would close over the loop variable and every entry
        # would end up constructing the *last* feature class.
        def _factory(*default_args, **default_kwargs):
            return ft(*default_args, **default_kwargs, **kwargs)

        return _factory

    return {name: _with_defaults(ft) for name, ft in features.items()}
class Spike_AP_upstroke(SpikeFeature):
    """Extract spike level upstroke feature.

    depends on: /.
    description: upstroke of AP.
    units: V/s.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Look up the per-spike upstroke; optionally store time, index and voltage."""
        lookup = self.lookup_spike_feature
        slope = lookup("upstroke", recompute=recompute)
        voltage = lookup("upstroke_v", recompute=recompute)
        time = lookup("upstroke_t", recompute=recompute)
        index = lookup("upstroke_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(upstroke_t=time, upstroke_idx=index, upstroke_v=voltage)
            )
        return slope

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Scatter the upstroke points and draw a short tangent through each."""
        if selected_idxs is None:
            idxs = slice(None)
        else:
            idxs = selected_idxs

        up_t, up_v = unpack(self.diagnostics, ["upstroke_t", "upstroke_v"])
        # NOTE(review): the 1e3 factor looks like a unit conversion - confirm.
        slope = self.value * 1e3

        half_span = 15e-5  # half-length of the tangent segment along the time axis
        t0 = up_t[idxs]
        t = np.linspace(t0 - half_span, t0 + half_span, 2)

        ax = scatter_spike_ft(
            "upstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        # Colour for the tangent is taken from the axes' line colour cycler.
        kwargs["color"] = ax._get_lines._cycler_items[0]["color"]
        tangent = slope[idxs] * (t - up_t[idxs]) + up_v[idxs]
        ax.plot(t, tangent, **kwargs)
        return ax
class Spike_AP_downstroke(SpikeFeature):
    """Extract spike level downstroke feature.

    depends on: /.
    description: downstroke of AP.
    units: V/s.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Look up the per-spike downstroke; optionally store time, voltage and index."""
        lookup = self.lookup_spike_feature
        slope = lookup("downstroke", recompute=recompute)
        time = lookup("downstroke_t", recompute=recompute)
        voltage = lookup("downstroke_v", recompute=recompute)
        index = lookup("downstroke_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(downstroke_t=time, downstroke_v=voltage, downstroke_idx=index)
            )
        return slope

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Scatter the downstroke points and draw a short tangent through each."""
        if selected_idxs is None:
            idxs = slice(None)
        else:
            idxs = selected_idxs

        down_t, down_v = unpack(self.diagnostics, ["downstroke_t", "downstroke_v"])
        # NOTE(review): the 1e3 factor looks like a unit conversion - confirm.
        slope = self.value * 1e3

        half_span = 25e-5  # half-length of the tangent segment along the time axis
        t0 = down_t[idxs]
        t = np.linspace(t0 - half_span, t0 + half_span, 2)

        ax = scatter_spike_ft(
            "downstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        # Colour for the tangent is taken from the axes' line colour cycler.
        kwargs["color"] = ax._get_lines._cycler_items[0]["color"]
        tangent = slope[idxs] * (t - down_t[idxs]) + down_v[idxs]
        ax.plot(t, tangent, **kwargs)
        return ax
class Spike_AP_fast_trough(SpikeFeature):
    """Extract spike level fast trough feature.

    depends on: /.
    description: fast trough of AP.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `fast_trough_v`; optionally store the matching t, i and index."""
        lookup = self.lookup_spike_feature
        trough_v = lookup("fast_trough_v", recompute=recompute)
        trough_i = lookup("fast_trough_i", recompute=recompute)
        trough_t = lookup("fast_trough_t", recompute=recompute)
        trough_idx = lookup("fast_trough_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(
                    fast_trough_t=trough_t,
                    fast_trough_i=trough_i,
                    fast_trough_idx=trough_idx,
                )
            )
        return trough_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the fast trough to `scatter_spike_ft`."""
        ax = scatter_spike_ft(
            "fast_trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        return ax
class Spike_AP_slow_trough(SpikeFeature):
    """Extract spike level slow trough feature.

    depends on: /.
    description: slow trough of AP.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `slow_trough_v`; optionally store the matching t, i and index."""
        lookup = self.lookup_spike_feature
        trough_v = lookup("slow_trough_v", recompute=recompute)
        trough_i = lookup("slow_trough_i", recompute=recompute)
        trough_t = lookup("slow_trough_t", recompute=recompute)
        trough_idx = lookup("slow_trough_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(
                    slow_trough_t=trough_t,
                    slow_trough_i=trough_i,
                    slow_trough_idx=trough_idx,
                )
            )
        return trough_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the slow trough to `scatter_spike_ft`."""
        ax = scatter_spike_ft(
            "slow_trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        return ax
class Spike_AP_amp(SpikeFeature):
    """Extract spike level peak height feature.

    depends on: threshold_v, peak_v.
    description: v_peak - threshold_v.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `peak_v - threshold_v` per spike (empty int array if no spikes)."""
        v_peak = self.lookup_spike_feature("peak_v", recompute=recompute)
        t_peak = self.lookup_spike_feature("peak_t", recompute=recompute)
        v_thresh = self.lookup_spike_feature("threshold_v", recompute=recompute)

        amplitude = v_peak - v_thresh

        if store_diagnostics:
            self._update_diagnostics(
                dict(peak_v=v_peak, peak_t=t_peak, threshold_v=v_thresh)
            )

        if len(v_peak) > 0:
            return amplitude
        return np.array([], dtype=int)

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Mark each peak and draw a dashed line from threshold to peak voltage."""
        if not has_spike_feature(self.data, "threshold_v"):
            return ax

        idxs = selected_idxs if selected_idxs is not None else slice(None)
        thresh_v, peak_t, peak_v = unpack(
            self.diagnostics, ["threshold_v", "peak_t", "peak_v"]
        )
        sel_t = peak_t[idxs]
        sel_peak = peak_v[idxs]

        ax.plot(sel_t, sel_peak, "x", **kwargs)
        ax.vlines(
            sel_t,
            thresh_v[idxs],
            sel_peak,
            ls="--",
            label="ap_amp",
            **kwargs,
        )
        return ax
class Spike_AP_AHP(SpikeFeature):
    """Extract spike level after hyperpolarization feature.

    depends on: threshold_v, fast_trough_v.
    description: v_fast_trough - threshold_v.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `fast_trough_v - threshold_v`; optionally store the inputs."""
        lookup = self.lookup_spike_feature
        trough_v = lookup("fast_trough_v", recompute=recompute)
        trough_t = lookup("fast_trough_t", recompute=recompute)
        thresh_v = lookup("threshold_v", recompute=recompute)
        thresh_t = lookup("threshold_t", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(
                    fast_trough_v=trough_v,
                    fast_trough_t=trough_t,
                    threshold_v=thresh_v,
                    threshold_t=thresh_t,
                )
            )
        return trough_v - thresh_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Draw a dashed vertical line between trough and threshold voltage."""
        if not has_spike_feature(self.data, "ap_ahp"):
            return ax

        idxs = selected_idxs if selected_idxs is not None else slice(None)
        trough_t, trough_v, threshold_t, threshold_v = unpack(
            self.diagnostics,
            ["fast_trough_t", "fast_trough_v", "threshold_t", "threshold_v"],
        )
        # The line sits halfway (in time) between the trough and the threshold.
        x_mid = 0.5 * (trough_t[idxs] + threshold_t[idxs])
        ax.vlines(
            x_mid,
            trough_v[idxs],
            threshold_v[idxs],
            ls="--",
            lw=1,
            label="ahp",
            **kwargs,
        )
        return ax
class Spike_AP_ADP(SpikeFeature):
    """Extract spike level after depolarization feature.

    depends on: adp_v, fast_trough_v.
    description: v_adp - v_fast_trough.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `adp_v - fast_trough_v`; optionally store the looked-up values."""
        lookup = self.lookup_spike_feature
        adp_v = lookup("adp_v", recompute=recompute)
        adp_t = lookup("adp_t", recompute=recompute)
        adp_i = lookup("adp_i", recompute=recompute)
        adp_idx = lookup("adp_index", recompute=recompute)
        trough_v = lookup("fast_trough_v", recompute=recompute)
        trough_t = lookup("fast_trough_t", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(
                    adp_v=adp_v,
                    fast_trough_v=trough_v,
                    fast_trough_t=trough_t,
                    adp_t=adp_t,
                    adp_i=adp_i,
                    adp_idx=adp_idx,
                )
            )
        return adp_v - trough_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Draw a dashed vertical line between ADP and fast-trough voltage."""
        if not has_spike_feature(self.data, "ap_adp"):
            return ax

        idxs = selected_idxs if selected_idxs is not None else slice(None)
        adp_t, adp_v, trough_t, trough_v = unpack(
            self.diagnostics,
            ["adp_t", "adp_v", "fast_trough_t", "fast_trough_v"],
        )
        # The line sits halfway (in time) between the ADP and the fast trough.
        x_mid = 0.5 * (adp_t[idxs] + trough_t[idxs])
        ax.vlines(
            x_mid,
            adp_v[idxs],
            trough_v[idxs],
            ls="--",
            lw=1,
            label="adp",
            **kwargs,
        )
        return ax
class Spike_AP_ADP_trough(SpikeFeature):
    """Extract spike level after depolarization feature.

    depends on: adp_v.
    description: |v_adp|.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return the absolute value of `adp_v`; optionally store v, t, i and index."""
        lookup = self.lookup_spike_feature
        adp_v = lookup("adp_v", recompute=recompute)
        adp_t = lookup("adp_t", recompute=recompute)
        adp_i = lookup("adp_i", recompute=recompute)
        adp_idx = lookup("adp_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(adp_v=adp_v, adp_t=adp_t, adp_i=adp_i, adp_idx=adp_idx)
            )
        return np.abs(adp_v)

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Draw a dashed vertical line from each ADP voltage to 0."""
        if not has_spike_feature(self.data, "ap_adp"):
            return ax

        idxs = selected_idxs if selected_idxs is not None else slice(None)
        adp_t, adp_v = unpack(self.diagnostics, ["adp_t", "adp_v"])
        ax.vlines(
            adp_t[idxs],
            adp_v[idxs],
            0,
            ls="--",
            lw=1,
            label="adp trough",
            **kwargs,
        )
        return ax
class Spike_AP_peak(SpikeFeature):
    """Extract spike level peak feature.

    depends on: peak_v.
    description: max voltage of AP.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `peak_v`; optionally store the matching t, i and index."""
        lookup = self.lookup_spike_feature
        peak_v = lookup("peak_v", recompute=recompute)
        peak_t = lookup("peak_t", recompute=recompute)
        peak_i = lookup("peak_i", recompute=recompute)
        peak_idx = lookup("peak_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(peak_t=peak_t, peak_i=peak_i, peak_idx=peak_idx)
            )
        return peak_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the peak to `scatter_spike_ft`."""
        ax = scatter_spike_ft(
            "peak", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        return ax
class Spike_AP_overshoot(SpikeFeature):
    """Extract spike level overshoot feature.

    depends on: peak_v.
    description: max voltage of AP above 0 mV.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `peak_v` with every negative entry replaced by NaN.

        Args:
            recompute: Forwarded to `lookup_spike_feature`.
            store_diagnostics: If True, store peak time, current and index
                in the diagnostics.

        Returns:
            Float array of per-spike peak voltages; peaks below 0 are NaN.
        """
        v_peak = self.lookup_spike_feature("peak_v", recompute=recompute)
        t_peak = self.lookup_spike_feature("peak_t", recompute=recompute)
        i_peak = self.lookup_spike_feature("peak_i", recompute=recompute)
        idx_peak = self.lookup_spike_feature("peak_index", recompute=recompute)

        # Work on a float copy: masking the looked-up array in place could
        # overwrite the shared `peak_v` values other features read, and NaN
        # cannot be stored in an integer array.
        overshoot = np.array(v_peak, dtype=float)
        overshoot[overshoot < 0] = np.nan

        if store_diagnostics:
            self._update_diagnostics(
                {
                    "peak_t": t_peak,
                    "peak_i": i_peak,
                    "peak_idx": idx_peak,
                }
            )
        return overshoot

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the overshoot to `scatter_spike_ft`."""
        return scatter_spike_ft(
            "overshoot", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
class Spike_AP_thresh(SpikeFeature):
    """Extract spike level ap threshold feature.

    depends on: threshold_v.
    description: For details on how AP thresholds are computed see AllenSDK.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `threshold_v`; optionally store the matching t, i and index."""
        lookup = self.lookup_spike_feature
        thresh_v = lookup("threshold_v", recompute=recompute)
        thresh_t = lookup("threshold_t", recompute=recompute)
        thresh_i = lookup("threshold_i", recompute=recompute)
        thresh_idx = lookup("threshold_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(
                    threshold_t=thresh_t,
                    threshold_i=thresh_i,
                    threshold_idx=thresh_idx,
                )
            )
        return thresh_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the threshold to `scatter_spike_ft`."""
        ax = scatter_spike_ft(
            "threshold", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        return ax
class Spike_AP_trough(SpikeFeature):
    """Extract spike level ap trough feature.

    depends on: trough_v.
    description: For details on how AP troughs are computed see AllenSDK.
    units: mV.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `trough_v`; optionally store the matching t, i and index."""
        lookup = self.lookup_spike_feature
        trough_v = lookup("trough_v", recompute=recompute)
        trough_t = lookup("trough_t", recompute=recompute)
        trough_i = lookup("trough_i", recompute=recompute)
        trough_idx = lookup("trough_index", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                dict(trough_t=trough_t, trough_i=trough_i, trough_idx=trough_idx)
            )
        return trough_v

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Delegate plotting of the trough to `scatter_spike_ft`."""
        ax = scatter_spike_ft(
            "trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
        )
        return ax
class Spike_AP_width(SpikeFeature):
    """Extract spike level ap width feature.

    depends on: width.
    description: full width half max of AP.
    units: s.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return the per-spike `width`; optionally store the indices used for plotting.

        Args:
            recompute: Forwarded to every `lookup_spike_feature` call.
            store_diagnostics: If True, store trough, threshold and peak
                indices (cast to int) in the diagnostics.

        Returns:
            The looked-up `width` values.
        """
        width = self.lookup_spike_feature("width", recompute=recompute)
        # NOTE(review): `.astype(int)` on an index array containing NaN yields
        # an invalid index - confirm the lookups cannot return NaN here.
        trough_idxs = self.lookup_spike_feature(
            "trough_index", recompute=recompute
        ).astype(int)
        spike_idxs = self.lookup_spike_feature(
            "threshold_index", recompute=recompute
        ).astype(int)
        peak_idxs = self.lookup_spike_feature(
            "peak_index", recompute=recompute
        ).astype(int)

        if store_diagnostics:
            self._update_diagnostics(
                {
                    "trough_idx": trough_idxs,
                    "spike_idx": spike_idxs,
                    "peak_idx": peak_idxs,
                }
            )
        return width

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Draw a dashed horizontal line of length `width` at each spike's half-max level.

        Spikes for which no half-max crossing is found between peak and
        threshold get a NaN start time instead of an invalid index.
        """
        if not has_spike_feature(self.data, "width"):
            return ax

        idxs = slice(None) if selected_idxs is None else selected_idxs

        # the following is adapted from `allen_sdk.ephys_features.find_widths`
        trough_idxs, spike_idxs, peak_idxs = unpack(
            self.diagnostics, ["trough_idx", "spike_idx", "peak_idx"]
        )
        t = self.data.t
        v = self.data.v

        ap_height = v[peak_idxs] - v[trough_idxs]
        trough_fwhm = ap_height / 2.0 + v[trough_idxs]
        thresh_fwhm = (v[peak_idxs] - v[spike_idxs]) / 2.0 + v[spike_idxs]

        # Some spikes in burst may have deep trough but short height, so can't
        # use same definition for width.
        # (Named `half_max_v` so the imported `fwhm` helper is not shadowed.)
        half_max_v = trough_fwhm.copy()
        below_thresh = trough_fwhm < v[spike_idxs]
        half_max_v[below_thresh] = thresh_fwhm[below_thresh]

        # Walk back from the peak towards the threshold and take the first
        # sample at or below the half-max level as the start of the width line.
        width_t = np.full(half_max_v.shape, np.nan)
        for i, (pk, spk, level) in enumerate(zip(peak_idxs, spike_idxs, half_max_v)):
            crossings = np.flatnonzero(v[pk:spk:-1] <= level)
            if crossings.size > 0:
                width_t[i] = t[pk - crossings[0]]

        half_max_v = half_max_v[idxs]
        width_t = width_t[idxs]
        width = self.lookup_spike_feature("width")[idxs]

        ax.hlines(
            half_max_v, width_t, width_t + width, label="width", ls="--", **kwargs
        )
        return ax
class Spike_AP_UDR(SpikeFeature):
    """Extract spike level ap udr feature.

    depends on: upstroke, downstroke.
    description: upstroke / downstroke. For details on how upstroke, downstroke
    are computed see AllenSDK.
    units: /.
    """

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Return `upstroke / -downstroke` per spike.

        Args:
            recompute: Forwarded to `lookup_spike_feature`.
            store_diagnostics: If True, store upstroke/downstroke values with
                their times and voltages in the diagnostics.

        Returns:
            The element-wise ratio `upstroke / -downstroke`.
        """
        upstroke = self.lookup_spike_feature("upstroke", recompute=recompute)
        upstroke_t = self.lookup_spike_feature("upstroke_t", recompute=recompute)
        upstroke_v = self.lookup_spike_feature("upstroke_v", recompute=recompute)
        downstroke = self.lookup_spike_feature("downstroke", recompute=recompute)
        downstroke_t = self.lookup_spike_feature("downstroke_t", recompute=recompute)
        downstroke_v = self.lookup_spike_feature("downstroke_v", recompute=recompute)

        if store_diagnostics:
            self._update_diagnostics(
                {
                    "upstroke": upstroke,
                    "upstroke_v": upstroke_v,
                    "upstroke_t": upstroke_t,
                    "downstroke": downstroke,
                    "downstroke_v": downstroke_v,
                    "downstroke_t": downstroke_t,
                }
            )
        return upstroke / -downstroke

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Mark the upstroke and downstroke points of the selected spikes."""
        if not has_spike_feature(self.data, "threshold_t"):
            return ax

        idxs = slice(None) if selected_idxs is None else selected_idxs
        up_t, up_v, down_t, down_v = unpack(
            self.diagnostics,
            ["upstroke_t", "upstroke_v", "downstroke_t", "downstroke_v"],
        )
        ax.plot(up_t[idxs], up_v[idxs], "x", label="upstroke", **kwargs)
        # Downstroke markers get their own legend label; they were previously
        # mislabelled "upstroke".
        ax.plot(down_t[idxs], down_v[idxs], "x", label="downstroke", **kwargs)
        return ax
class Spike_ISI(SpikeFeature):
    """Extract spike level inter-spike-interval feature.

    depends on: threshold_t.
    description: The distance between subsequent spike thresholds, i.e.
    isi[t+1] = threshold_t[t+1] - threshold_t[t]. With more than one spike the
    first entry is 0; with exactly one spike the single entry is nan.
    units: s.
    """

    # NOTE(review): earlier documentation said the first ISI is nan, but the
    # code inserts 0 when there is more than one spike - confirm which is intended.

    def __init__(self, data=None, **kwargs):
        super().__init__(data, **kwargs)

    def _compute(self, recompute=False, store_diagnostics=True):
        """Compute per-spike ISIs from `threshold_t`; optionally store diagnostics."""
        spike_times = self.lookup_spike_feature("threshold_t", recompute=recompute)
        spike_thresh = self.lookup_spike_feature("threshold_v", recompute=recompute)

        n_spikes = len(spike_times)
        if n_spikes > 1:
            # Prepend a 0 so the result has one entry per spike.
            isi = np.insert(np.diff(spike_times), 0, 0)
        elif n_spikes == 1:
            isi = np.array([float("nan")])
        else:
            isi = np.array([], dtype=int)

        if store_diagnostics:
            self._update_diagnostics(
                dict(spike_times=spike_times, spike_thresh=spike_thresh, isi=isi)
            )
        return isi

    def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
        """Draw each ISI as a dashed line ending at its spike threshold marker."""
        if not has_spike_feature(self.data, "isi"):
            return ax

        idxs = selected_idxs if selected_idxs is not None else slice(None)
        all_t, all_v, all_isi = unpack(
            self.diagnostics, ["spike_times", "spike_thresh", "isi"]
        )
        thresh_t = all_t[idxs]
        thresh_v = all_v[idxs]
        isi = all_isi[idxs]

        ax.hlines(thresh_v, thresh_t - isi, thresh_t, ls="--", label="isi", **kwargs)
        ax.plot(thresh_t, thresh_v, "x", **kwargs)
        return ax