Toolkit documentation
Besides being a spike sorter, dartsort is a toolkit for common spike sorting analyses, and the main spike sorter is built out of these tools. The most important tools are the peelers, which detect spikes and featurize them; they rely on featurization pipelines. Other tools include clustering and the estimation of template waveforms from spike trains.
Detecting, cleaning, and featurizing spikes ("peeling")
Spike sorting workflows such as template matching or threshold-based spike detection share a common sequence of steps: spikes are detected, estimated clean waveforms may be subtracted from the recording iteratively, and finally waveforms are extracted and featurized.
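To make the detect, subtract, featurize sequence concrete, here is a toy single-channel sketch in NumPy. It is not dartsort's implementation. The helper name `peel`, the fixed known waveform shape, and the threshold rule are assumptions for illustration. The loop repeatedly finds the deepest trough, fits a scaled copy of the waveform there, subtracts it from the residual, and records the time and fitted amplitude as features.

```python
import numpy as np


# Hypothetical helper for illustration; not part of dartsort.
def peel(trace, waveform, threshold, max_iter=100):
    """Toy peeling loop on one channel: detect, subtract, featurize, repeat.

    `waveform` is a known spike shape whose trough (minimum) marks the event
    time. Returns detected times, fitted amplitudes, and the final residual.
    """
    residual = np.asarray(trace, dtype=float).copy()
    waveform = np.asarray(waveform, dtype=float)
    n = len(waveform)
    trough = int(np.argmin(waveform))  # offset of the trough within the waveform
    times, amplitudes = [], []
    for _ in range(max_iter):
        # Detection: deepest trough at which a full waveform fits in the trace.
        lo, hi = trough, len(residual) - n + trough + 1
        t = lo + int(np.argmin(residual[lo:hi]))
        if -residual[t] < threshold:
            break  # nothing left that crosses the threshold
        # Cleaning: least-squares scaling of the waveform at t, then subtraction.
        start = t - trough
        snippet = residual[start:start + n]
        scale = float(snippet @ waveform / (waveform @ waveform))
        residual[start:start + n] -= scale * waveform
        # Featurization: here only the event time and the fitted amplitude.
        times.append(t)
        amplitudes.append(scale)
    return np.array(times), np.array(amplitudes), residual
```

Because each subtraction removes an event before the next detection, a small spike hidden next to a larger one can still be found on a later iteration. That is the intuition behind calling these workflows "peeling".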
In dartsort, workflows of this kind are called "peelers" (inspired by the iterative subtraction, or peeling, of events). They are implemented as subclasses of a BasePeeler class, which handles the shared logic (processing chunks of the recording in parallel, fitting featurization models, ...).
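The chunked, parallel processing that BasePeeler takes care of can be pictured with the following sketch. It is a hypothetical illustration, not dartsort's code. The names `chunk_bounds` and `process_in_chunks`, the margin handling, and the use of a thread pool are assumptions. In this sketch, each chunk is read with extra margin samples on both sides so that events near chunk edges are processed with full context. Only events whose times fall inside the chunk's core are kept, so no event is reported twice.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


# Hypothetical helpers for illustration; not part of dartsort.
def chunk_bounds(n_samples, chunk_length):
    """Start and end sample of each non-overlapping chunk core."""
    starts = np.arange(0, n_samples, chunk_length)
    ends = np.minimum(starts + chunk_length, n_samples)
    return list(zip(starts.tolist(), ends.tolist()))


def process_in_chunks(trace, detect, chunk_length, margin, n_workers=4):
    """Run `detect` on padded chunks in parallel and merge the results.

    `detect(segment)` must return event times relative to the segment start.
    """

    def run_chunk(bounds):
        start, end = bounds
        lo = max(start - margin, 0)
        hi = min(end + margin, len(trace))
        times = np.asarray(detect(trace[lo:hi]), dtype=int) + lo
        # Keep only events in this chunk's core so neighbors don't duplicate them.
        return times[(times >= start) & (times < end)]

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        per_chunk = list(pool.map(run_chunk, chunk_bounds(len(trace), chunk_length)))
    return np.sort(np.concatenate(per_chunk))
```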
All of the peelers accept a featurization pipeline or a FeaturizationConfig object, which handles spike featurization; both are discussed in the Featurization pipelines section below.
dartsort includes high-level functions, with corresponding configuration objects, for running the various kinds of peelers. These are documented in the sections below.
Template matching
If its template_data parameter is set, match() runs template matching against those known templates. Otherwise, it first estimates templates from its sorting argument using estimate_template_library().
It is configured by the matching_cfg argument, a MatchingConfig.
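For readers new to template matching, the greedy pursuit idea can be sketched on a single channel as below. This is a generic textbook version, not dartsort's matcher. Judging by MatchingConfig, dartsort also deals with multi-channel, compressed, upsampled, and drift-aware templates and with amplitude scaling. The helper name `greedy_match` and its stopping threshold are assumptions. The sketch relies on the identity ||r||^2 - ||r - t||^2 = 2<r, t> - ||t||^2. Subtracting template t at a given position therefore reduces the squared residual norm by 2<r, t> - ||t||^2. The pursuit repeatedly picks the template and time with the largest reduction, and it stops when no placement reduces the norm by more than the threshold.

```python
import numpy as np


# Hypothetical helper for illustration; not part of dartsort.
def greedy_match(trace, templates, threshold=0.0, max_iter=1000):
    """Greedy matching pursuit on a 1-D trace.

    templates: array of shape (n_templates, n_times). Returns a list of
    (start_sample, template_index) pairs and the final residual.
    """
    residual = np.asarray(trace, dtype=float).copy()
    templates = np.asarray(templates, dtype=float)
    norms_sq = (templates ** 2).sum(axis=1)
    matches = []
    for _ in range(max_iter):
        # conv[k, s] = <residual[s:s+T], templates[k]> for every valid start s.
        conv = np.stack([np.correlate(residual, tpl, mode="valid") for tpl in templates])
        # Decrease in squared residual norm if template k is subtracted at s.
        objective = 2.0 * conv - norms_sq[:, None]
        k, s = np.unravel_index(np.argmax(objective), objective.shape)
        if objective[k, s] <= threshold:
            break
        residual[s:s + templates.shape[1]] -= templates[k]
        matches.append((int(s), int(k)))
    return matches, residual
```

Recomputing the full correlation on every iteration keeps the sketch short but is wasteful. A practical matcher would update the objective only near the last subtraction.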
dartsort.match
match(output_dir: str | Path, recording: BaseRecording, sorting: DARTsortSorting | None = None, motion: MotionInfo | None = None, waveform_cfg: WaveformConfig = default_waveform_cfg, template_cfg=default_template_cfg, featurization_cfg: FeaturizationConfig = default_featurization_cfg, matching_cfg=default_matching_cfg, sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, previous_detection_cfg: Any | None = None, prev_step_name: str | None = None, save_cfg: DARTsortInternalConfig | None = None, chunk_starts_samples=None, stop_after_n_spikes: int | None = None, ensure_coverage: float | None = None, overwrite=False, residual_filename: str | None = None, skip_resid_snips=False, show_progress=True, hdf5_filename='matching0.h5', model_subdir='matching0_models', template_data: TemplateData | None = None, template_npz='template_data.npz', computation_cfg: ComputationConfig | None = None, template_denoising_tsvd=None, whitener: Whitener | None = None) -> DARTsortSorting
dartsort.MatchingConfig
Template matching pursuit parameters
template_svd_compression_rank: int = 5
template_svd_compression_min_explained_variance: float = 0.005
template_min_channel_amplitude: float = 1.0
amplitude_scaling_variance: float = 0.01 ** 2
amplitude_scaling_boundary: float = 1.0 / 3.0
coarse_approx_error_threshold: float = 0.0
channel_selection: Literal['template', 'amplitude'] = 'template'
channel_selection_radius: float | None = None
template_type: Literal['individual_compressed_upsampled', 'drifty', 'debug'] = 'drifty'
up_method: Literal['interpolation', 'keys3', 'keys4', 'direct'] = 'keys4'
drift_interp_params
upsampling_compression_map: Literal['yass', 'none'] = 'yass'
whitening
template_merge_cfg
template_realignment_cfg
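The template_svd_compression_rank option suggests that templates are stored in low-rank form. The sketch below shows the general idea under that assumption: a time-by-channel template is approximated by its rank-k truncated SVD. It is not dartsort's exact procedure. The helper names `compress_template` and `reconstruct` are hypothetical, and the sketch does not reproduce how template_svd_compression_min_explained_variance is used.

```python
import numpy as np


# Hypothetical helpers for illustration; not part of dartsort.
def compress_template(template, rank=5):
    """Rank-`rank` truncated SVD of a (n_times, n_channels) template.

    Returns temporal factors (n_times, rank), singular values (rank,),
    spatial factors (rank, n_channels), and the fraction of the template's
    squared norm captured by the kept components.
    """
    u, s, vt = np.linalg.svd(template, full_matrices=False)
    rank = min(rank, len(s))
    explained = float((s[:rank] ** 2).sum() / (s ** 2).sum())
    return u[:, :rank], s[:rank], vt[:rank], explained


def reconstruct(temporal, singular_values, spatial):
    """Rebuild the (n_times, n_channels) template from its low-rank factors."""
    return (temporal * singular_values) @ spatial
```

A rank-k factorization stores k * (n_times + n_channels + 1) numbers instead of n_times * n_channels. It also lets a correlation with the template be computed as k temporal correlations of spatially projected data, which is one common reason to compress templates before matching.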
Neural-net-based collision-cleaned spike detection
dartsort.subtract
subtract(output_dir: str | Path, recording: BaseRecording, waveform_cfg: WaveformConfig = default_waveform_cfg, featurization_cfg: FeaturizationConfig = default_featurization_cfg, subtraction_cfg=default_subtraction_cfg, sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, computation_cfg: ComputationConfig | None = None, chunk_starts_samples=None, stop_after_n_spikes: int | None = None, ensure_coverage: float | None = None, overwrite=False, residual_filename: str | None = None, shuffle: bool = False, show_progress=True, hdf5_filename='subtraction.h5', model_subdir='subtraction_models') -> DARTsortSorting
dartsort.SubtractionConfig
Parameters for neural-net-based spike detection
denoiser_realignment_channel: Literal['detection', 'denoised'] = 'detection'
relative_peak_radius_samples: int = 5
relative_peak_radius_um: float | None = 35.0
spatial_dedup_radius_um: float | None = 50.0
temporal_dedup_radius_samples: int = 7
positive_temporal_dedup_radius_samples: int = 41
residnorm_decrease_threshold: float = 9.0
decrease_objective: Literal['norm', 'normsq', 'deconv'] = 'deconv'
threshold_before_whitening: float = 10.0
whiten_cfg
subtraction_denoising_cfg: FeaturizationConfig = FeaturizationConfig(denoise_only=True, do_nn_denoise=True, extract_radius=200.0, input_waveforms_name='raw', output_waveforms_name='subtracted')
first_denoiser_max_waveforms_fit: int = 512000
first_denoiser_noise_snips: int = 100 * 256
first_denoiser_noise_snip_length_mul: float = 2.5
first_denoiser_noise_density: float = 0.5
first_denoiser_temporal_jitter: int = 3
first_denoiser_spatial_jitter: float = 35.0
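The names residnorm_decrease_threshold and decrease_objective suggest that a denoised waveform is subtracted only if doing so shrinks the residual enough. The sketch below illustrates that kind of acceptance test under stated assumptions. It uses a squared-norm criterion, which may or may not correspond to any of the 'norm', 'normsq', or 'deconv' options. It is not dartsort's exact rule, and the helper name `accept_subtraction` is hypothetical.

```python
import numpy as np


# Hypothetical helper for illustration; not dartsort's actual acceptance rule.
def accept_subtraction(residual_snippet, denoised_waveform, threshold):
    """Accept a subtraction if the squared residual norm drops enough.

    Uses ||r||^2 - ||r - w||^2 = 2<r, w> - ||w||^2, so the decrease can be
    computed without forming r - w. Returns (accepted, decrease).
    """
    r = np.ravel(residual_snippet).astype(float)
    w = np.ravel(denoised_waveform).astype(float)
    decrease = 2.0 * (r @ w) - w @ w
    return bool(decrease > threshold), float(decrease)
```

A test of this kind keeps the denoiser from subtracting a confident-looking waveform where the data do not support it, which would otherwise add artifacts to the residual.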
Thresholding spike detection
dartsort.threshold
threshold(output_dir: str | Path, recording: BaseRecording, waveform_cfg: WaveformConfig = default_waveform_cfg, thresholding_cfg: ThresholdingConfig = default_thresholding_cfg, featurization_cfg: FeaturizationConfig = default_featurization_cfg, featurization_pipeline: WaveformPipeline | None = None, sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, extract_channel_index: Tensor | None = None, chunk_starts_samples=None, stop_after_n_spikes: int | None = None, ensure_coverage: float | None = None, overwrite=False, show_progress=True, hdf5_filename='threshold.h5', model_subdir='threshold_models', computation_cfg: ComputationConfig | None = None) -> DARTsortSorting
dartsort.ThresholdingConfig
Parameters for threshold-crossing spike detection
relative_peak_radius_um: float | None = 35.0
relative_peak_radius_samples: int = 5
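The names relative_peak_radius_samples and relative_peak_radius_um suggest a neighborhood in time (samples) and space (microns) within which a detected peak must be the largest. The sketch below shows threshold-crossing detection with that kind of local-maximum deduplication. The function name `detect_peaks`, the use of absolute voltage, and the exact neighborhood rule are assumptions rather than dartsort's implementation.

```python
import numpy as np


# Hypothetical helper for illustration; not part of dartsort.
def detect_peaks(traces, geom, threshold, radius_samples=5, radius_um=35.0):
    """Threshold crossings that are local maxima of |voltage|.

    traces: (n_samples, n_channels); geom: (n_channels, 2) channel positions
    in microns. A (time, channel) pair is kept if |voltage| exceeds
    `threshold` and nothing within `radius_samples` on any channel within
    `radius_um` is larger. Ties are all kept. Returns (times, channels).
    """
    energy = np.abs(np.asarray(traces, dtype=float))
    geom = np.asarray(geom, dtype=float)
    n_samples = energy.shape[0]
    dist = np.linalg.norm(geom[:, None, :] - geom[None, :, :], axis=-1)
    times, channels = np.nonzero(energy > threshold)
    keep = []
    for t, c in zip(times, channels):
        lo, hi = max(t - radius_samples, 0), min(t + radius_samples + 1, n_samples)
        neighbors = dist[c] <= radius_um
        keep.append(energy[t, c] >= energy[lo:hi, neighbors].max())
    keep = np.asarray(keep, dtype=bool)
    return times[keep], channels[keep]
```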
Spike extraction and featurization at known event times
dartsort.grab
grab(output_dir: str | Path, recording: BaseRecording, sorting: DARTsortSorting, waveform_cfg: WaveformConfig = default_waveform_cfg, featurization_cfg: FeaturizationConfig = default_featurization_cfg, sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, chunk_starts_samples=None, overwrite=False, show_progress=True, hdf5_filename='grab.h5', model_subdir='grab_models', computation_cfg: ComputationConfig | None = None) -> DARTsortSorting
Featurization pipelines
Featurization is configured by building a FeaturizationConfig.
Inside dartsort, the FeaturizationConfig is turned into a WaveformPipeline by the WaveformPipeline.from_config() constructor.
dartsort.FeaturizationConfig
Featurization and denoising configuration
Parameters for a featurization and denoising pipeline with the flow: [input waveforms] -> [featurization of input waveforms] -> [denoising] -> [featurization of output waveforms].
The flags below let users control which features are computed for the input waveforms, which denoising operations are applied, and which features are computed for the output (post-denoising) waveforms.
Users who want something not covered by this typical case can instantiate a WaveformPipeline manually and pass it to their peeler.
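The flow described above can be pictured as a list of stages applied in order: featurizers add named features, and denoisers replace the waveforms. The sketch below is a hypothetical stand-in for a WaveformPipeline, written with plain functions instead of dartsort.transform modules. The function names, stage names, and the toy denoiser are assumptions.

```python
import numpy as np


# Hypothetical stand-ins for illustration; not dartsort.transform modules.
def ptp_amplitudes(waveforms):
    """Featurizer: peak-to-peak amplitude per spike and channel."""
    return waveforms.max(axis=1) - waveforms.min(axis=1)


def zero_small_channels(waveforms, min_ptp=1.0):
    """Toy denoiser: zero out channels whose peak-to-peak amplitude is small."""
    mask = ptp_amplitudes(waveforms) >= min_ptp  # (n_spikes, n_channels)
    return waveforms * mask[:, None, :]


def run_pipeline(waveforms, stages):
    """Apply ("featurize", name, fn) and ("denoise", name, fn) stages in order.

    waveforms has shape (n_spikes, n_times, n_channels). Returns the final
    waveforms and a dict of named features.
    """
    features = {}
    for kind, name, fn in stages:
        if kind == "featurize":
            features[name] = fn(waveforms)
        elif kind == "denoise":
            waveforms = fn(waveforms)
        else:
            raise ValueError(f"unknown stage kind: {kind}")
    return waveforms, features


stages = [
    ("featurize", "input_ptp", ptp_amplitudes),      # features of input waveforms
    ("denoise", "zero_small", zero_small_channels),  # denoising
    ("featurize", "denoised_ptp", ptp_amplitudes),   # features of output waveforms
]
```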
do_enforce_decrease
compute_input_tpca_projs_regardless: bool = False
localization_amplitude_type: Literal['peak', 'ptp'] = 'peak'
localization_model: Literal['pointsource', 'dipole'] = 'pointsource'
additional_com_localization: bool = False
nn_denoiser_class_name: str = 'Decollider'
nn_denoiser_pretrained_path: str | None = None
nn_denoiser_extra_kwargs: dict | None = argfield(None, cli=False)
learn_cleaned_tpca_basis: bool = False
input_tpca_waveform_cfg
pre_gmm_clustering_cfg: ClusteringConfig | None = None
pre_gmm_refinement_cfgs: Sequence[RefinementConfig | None] | None = None
gmm_refinement_cfg: RefinementConfig | None = None
gmm_clustering_features_cfg: ClusteringFeaturesConfig | None = None
dartsort.WaveformPipeline
Bases: Module
Pipelines of featurization nodes.
from_state_dict_pt (classmethod)
Load a pipeline from file.
from_class_names_and_kwargs (classmethod)
from_class_names_and_kwargs(geom, channel_index, class_names_and_kwargs, waveform_cfg: WaveformConfig | None, sampling_frequency: float = 30000.0)
Construct a pipeline from a sequence of BaseWaveformModule class names and constructor arguments.
from_config (classmethod)
from_config(*, featurization_cfg: FeaturizationConfig, waveform_cfg: WaveformConfig, recording=None, geom=None, channel_index=None, sampling_frequency: float)
Construct a pipeline based on configuration options.
forward
forward(waveforms, *, up_to_index: int | None = None, start_index: int | None = None, **fixed_properties)
Run waveforms and fixed properties through the pipeline, extracting features and denoising.
fit
fit(recording: BaseRecording, waveforms: Tensor, computation_cfg: ComputationConfig, *, hdf5_filename: str | Path | None = None, waveforms_dataset_name: str = 'waveforms', **fixed_properties: Tensor)
Fit the pipeline's transformers in sequence, giving each one the outputs of its predecessors.
transform_to_disk
transform_to_disk(hdf5_filename: str | Path, waveforms_dataset_name: str | None = 'waveforms', other_dset_names: Sequence[str] | None = None, start_index: int | None = None, up_to_index: int | None = None)
Run in batches through waveforms saved in an h5 file and save the resulting features to new h5 datasets.
This pipeline is a sequence of denoising and featurization objects from the dartsort.transform module:
dartsort.transform
Spike featurization and denoising pipelines
AmortizedLocalization
Bases: BaseWaveformFeaturizer
Localize spike waveform sources in space with a neural network.
Order of output columns: x, y, z_abs.
local_distances
Return distances from each z to its local geom centered at channels.
AmplitudeFeatures
Bases: BaseWaveformFeaturizer
Extract spike amplitudes.
Decollider
Bases: BaseMultichannelDenoiser
Unsupervised spike waveform denoising.
EnforceDecrease
Bases: BaseWaveformDenoiser
A torch module for enforcing spatial decrease of amplitudes.
Calling an instance of this class on an N,T,C batch of waveforms returns a result whose peak-to-peak amplitudes decrease as you move away from the detection channel.
forward
Example usage:
    enfdec = EnforceDecrease(geom, channel_index)
    ...
    dec_wfs, dec_ptps = enfdec(waveforms, maxchans)
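As a rough illustration of the property EnforceDecrease guarantees, the sketch below handles a single spike. It orders the channels by distance from the detection channel and rescales each channel so that its peak-to-peak amplitude never exceeds that of any closer channel. The real module works on batches using dartsort's geometry and channel index. The function name `enforce_decrease_1d` and the running-minimum rule are assumptions, not its actual algorithm.

```python
import numpy as np


# Hypothetical helper for illustration; not EnforceDecrease's actual algorithm.
def enforce_decrease_1d(waveform, channel_positions, detection_channel):
    """Rescale channels so peak-to-peak amplitude is non-increasing with distance.

    waveform: (n_times, n_channels); channel_positions: (n_channels, d).
    Returns the rescaled waveform and its peak-to-peak amplitudes.
    """
    waveform = np.asarray(waveform, dtype=float)
    positions = np.asarray(channel_positions, dtype=float)
    ptp = np.ptp(waveform, axis=0)
    dist = np.linalg.norm(positions - positions[detection_channel], axis=1)
    order = np.argsort(dist, kind="stable")  # detection channel first (distance 0)
    # Running minimum of ptp, walking outward from the detection channel.
    target = np.empty_like(ptp)
    target[order] = np.minimum.accumulate(ptp[order])
    # Scale each channel down to its target amplitude; flat channels stay as-is.
    scale = np.divide(target, ptp, out=np.ones_like(ptp), where=ptp > 0)
    decreased = waveform * scale
    return decreased, np.ptp(decreased, axis=0)
```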
WaveformInterpolator
Bases: BaseWaveformDenoiser
Interpolate waveforms for motion correction.
Localization
Bases: BaseWaveformFeaturizer
Optimization-based spike source localization.
Order of output columns: x, y, z_abs, alpha.
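Localization's output columns (x, y, z_abs, alpha) match the parameters of a point-source amplitude model. In that model, the amplitude on a channel at probe position (x_c, z_c) is alpha / sqrt((x - x_c)^2 + y^2 + (z - z_c)^2), where y is the distance from the probe plane. The sketch below assumes that model and fits it by brute-force grid search, solving for alpha in closed form (least squares) at each candidate position. It does not reproduce dartsort's optimizer, its amplitude definition, or the dipole option. The function name `localize_point_source` and the grids are hypothetical.

```python
import itertools

import numpy as np


# Hypothetical helper for illustration; not dartsort's localization code.
def localize_point_source(amplitudes, geom, xs, ys, zs):
    """Grid-search fit of amplitude_c ~ alpha / ||(x, y, z) - channel_c||.

    amplitudes: (n_channels,) amplitudes of one spike.
    geom: (n_channels, 2) channel (x, z) positions in microns.
    xs, ys, zs: candidate values; ys are distances from the probe plane (> 0).
    Returns (x, y, z, alpha) minimizing the squared amplitude error.
    """
    amplitudes = np.asarray(amplitudes, dtype=float)
    geom = np.asarray(geom, dtype=float)
    best, best_err = None, np.inf
    for x, y, z in itertools.product(xs, ys, zs):
        dist = np.sqrt((x - geom[:, 0]) ** 2 + y ** 2 + (z - geom[:, 1]) ** 2)
        q = 1.0 / dist
        alpha = (amplitudes @ q) / (q @ q)  # least-squares alpha for this position
        err = np.sum((amplitudes - alpha * q) ** 2)
        if err < best_err:
            best, best_err = (float(x), float(y), float(z), float(alpha)), err
    return best
```

A grid search is only practical for a handful of spikes. An optimization-based localizer would instead minimize the same error with a gradient-based or quasi-Newton method.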
DebugMatchingPursuitDenoiser
Bases: BaseWaveformDenoiser
This denoiser is used for testing purposes only.
TruncatedMixtureModelTransformer
Bases: BaseWaveformFeaturizer
Gaussian mixture clustering and classification as a featurization node.
SingleChannelWaveformDenoiser
Bases: BaseWaveformDenoiser
YASS-style single-channel waveform denoising.
SupervisedDenoiser
Bases: BaseMultichannelDenoiser
Supervised multi-channel neural network waveform denoising.
forward_unbatched
Called only at inference time.
BaseTemporalPCA
TemporalPCA
Bases: BaseWaveformAutoencoder, TemporalPCAFeaturizer
Combined spike featurization and denoising with PCA.
TemporalPCADenoiser
TemporalPCAFeaturizer
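TemporalPCA is described as combined featurization and denoising with PCA. The general technique works like this: learn a low-dimensional basis for single-channel waveform snippets, use the projections onto it as features, and reconstruct from those projections to get a denoised waveform. The NumPy sketch below is a hypothetical minimal version, not the dartsort.transform implementation. The class name `TinyTemporalPCA`, the rank, and fitting on all channels pooled together are assumptions.

```python
import numpy as np


# Hypothetical class for illustration; not dartsort.transform.TemporalPCA.
class TinyTemporalPCA:
    """Temporal PCA over (n_spikes, n_times, n_channels) waveforms."""

    def __init__(self, rank=8):
        self.rank = rank

    def fit(self, waveforms):
        # Treat every channel of every spike as one temporal snippet.
        snippets = waveforms.transpose(0, 2, 1).reshape(-1, waveforms.shape[1])
        self.mean_ = snippets.mean(axis=0)
        _, _, vt = np.linalg.svd(snippets - self.mean_, full_matrices=False)
        self.components_ = vt[: self.rank]  # (rank, n_times)
        return self

    def transform(self, waveforms):
        """Featurize: projections with shape (n_spikes, rank, n_channels)."""
        centered = waveforms - self.mean_[None, :, None]
        return np.einsum("kt,ntc->nkc", self.components_, centered)

    def inverse_transform(self, projections):
        """Map projections back to waveforms (n_spikes, n_times, n_channels)."""
        return np.einsum("kt,nkc->ntc", self.components_, projections) + self.mean_[None, :, None]

    def denoise(self, waveforms):
        """Denoise by projecting onto the temporal basis and reconstructing."""
        return self.inverse_transform(self.transform(waveforms))
```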
BaseWaveformModule
Bases: BModule
Base class of all spike featurizers and denoisers.
Clustering
dartsort includes configuration options and a main function for running several clustering strategies.
dartsort.cluster
cluster(recording: BaseRecording, sorting: DARTsortSorting, motion: MotionInfo, clustering_cfg: ClusteringConfig | None = default_clustering_cfg, clustering_features_cfg: ClusteringFeaturesConfig | None = default_clustering_features_cfg, refinement_cfgs: Sequence[RefinementConfig | None] | None = None, computation_cfg: ComputationConfig | None = None, features: SimpleMatrixFeatures | None = None, *, _save_cfg: DARTsortInternalConfig | None = None, _save_initial_name='initial', _save_refined_name_fmt='refined0{stepname}', _save_dir=None)
dartsort.ClusteringFeaturesConfig
Parameters to control which features are used for initial clustering
pc_transform: Literal['log', 'sqrt', 'none'] | None = 'none'
interp_params
motion_depth_mode: Literal['channel', 'localization'] = 'channel'
amplitudes_dataset_name: str = 'denoised_ptp_amplitudes'
voltages_dataset_name: str = 'collisioncleaned_voltages'
amplitude_vectors_dataset_name: str = 'denoised_ptp_amplitude_vectors'
dartsort.ClusteringConfig
dartsort.RefinementConfig
Parameters for clustering refinement
sampling_cfg
distance_metric
mixture_steps: Sequence[MixtureStep] = ('split', 'merge', 'demolish')
robust_strategy: Literal['none', 'fixed'] = 'none'
robust_fixed_std_dataset: str = 'collidedness'
demolition_min_resp_ratio: float = 0.9
demolish_during_selection: bool = False
scale_dist_args
template_merge_cfg
glom_firing_corr_method: Literal['binsqrt'] = 'binsqrt'
qda_force_merge_for_temp_dist_below: float = 0.3
spikeinterface_merge_preset
spikeinterface_merge_max_distance: float = 0.8
spikeinterface_merge_min_coentropy: float | None = 0.01
spikeinterface_merge_coent_coverage: float = 0.8
spikeinterface_merge_coent_iou: float = 0.5
feature_scales
impute_kind: Literal['interp', 'impute'] = 'impute'
noise_interp_params
gmm_isolation_threshold: float | None = None
gmm_isolation_neighbor_fraction: float = 0.9
Template waveform estimation
dartsort.estimate_template_library
estimate_template_library(recording: BaseRecording, sorting: DARTsortSorting, motion: MotionInfo | None = None, min_template_snr: float = 0.0, min_template_ptp: float = 0.0, always_keep_ptp: float = 0.0, min_template_count: int = 0, max_cc_flag_rate: float = 1.0, cc_flag_entropy_cutoff: float = 0.0, waveform_cfg: WaveformConfig = default_waveform_cfg, template_cfg: TemplateConfig = default_template_cfg, realign_cfg: TemplateRealignmentConfig | None = None, template_merge_cfg: TemplateMergeConfig | None = None, tsvd: PCA | TruncatedSVD | None = None, whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, fit_featurization_tsvd: bool = False, featurization_cfg: FeaturizationConfig | None = None, detection_cfg: Any | None = None, depth_order: bool = False, template_npz_path=None) -> tuple[DARTsortSorting, TemplateData]
Postprocess a spike train and estimate a TemplateData from it.
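In its simplest form, estimating templates from a spike train means averaging each unit's waveforms and discarding units that are too small or too rarely observed. The min_template_ptp and min_template_count arguments point in that direction. The sketch below shows only this basic idea. The function `mean_templates` is hypothetical. Judging by its parameters, estimate_template_library() also handles motion, denoising, realignment, merging, and SNR and cross-correlation checks, all of which the sketch omits.

```python
import numpy as np


# Hypothetical helper for illustration; not part of dartsort.
def mean_templates(traces, times, labels, n_before=10, n_after=20,
                   min_count=0, min_ptp=0.0):
    """Average each unit's waveforms into a template and filter units.

    traces: (n_samples, n_channels); times, labels: per-spike arrays.
    Spikes too close to the recording edges are skipped. Returns
    (unit_ids, templates) with templates of shape
    (n_units, n_before + n_after, n_channels).
    """
    traces = np.asarray(traces, dtype=float)
    times = np.asarray(times)
    labels = np.asarray(labels)
    valid = (times >= n_before) & (times + n_after <= len(traces))
    unit_ids, templates = [], []
    for unit in np.unique(labels[valid]):
        unit_times = times[valid & (labels == unit)]
        if len(unit_times) < min_count:
            continue  # too few spikes to trust the average
        snippets = np.stack([traces[t - n_before:t + n_after] for t in unit_times])
        template = snippets.mean(axis=0)
        if np.ptp(template, axis=0).max() < min_ptp:
            continue  # template too small on every channel
        unit_ids.append(int(unit))
        templates.append(template)
    empty = np.zeros((0, n_before + n_after, traces.shape[1]))
    return np.array(unit_ids, dtype=int), (np.stack(templates) if templates else empty)
```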
dartsort.TemplateData
spike_counts_by_channel: ndarray | None = None
from_config (classmethod)
from_config(*, recording: BaseRecording, sorting: DARTsortSorting | None, template_cfg: TemplateConfig, waveform_cfg: WaveformConfig = default_waveform_cfg, save_folder: Path | None = None, overwrite=False, motion: MotionInfo | None = None, whitener: Whitener | None = None, save_npz_name: str | None = 'template_data.npz', tsvd=None, featurization_basis=None, computation_cfg: ComputationConfig | None = None, show_progress: bool = True) -> TemplateData
dartsort.TemplateConfig
Template waveform estimation parameters
algorithm
denoising_method: Literal['none', 'exp_weighted', 'svd'] = 'svd'
grab_chunk_length_samples: int = 30000
template_interp_params
denoising_fit_sampling_cfg: FitSamplingConfig = replace(default_peeling_fit_sampling_cfg, n_residual_snips=0)
template_min_channel_amplitude: float = 1.0
svd_min_explained_variance: float = 0.005
exp_weight_snr_threshold: float = 50.0
amplitudes_dataset_name: str = 'denoised_ptp_amplitudes'