LFADS uses a nonlinear dynamical system (a recurrent neural network) to infer the dynamics underlying observed population activity and to extract ‘denoised’ single-trial firing rates from neural spiking data. It is a sequential variational auto-encoder designed specifically for investigating neuroscience data, but it can be applied to any time series data. In an unsupervised setting, LFADS decomposes time series data into several components: an initial condition, a generative dynamical system, control inputs to that generator, and a low-dimensional description of the observed data, called the factors. Additionally, because the observation model is a loss on a probability distribution, LFADS also produces a denoised version of the dataset. For example, if the dataset is raw spike counts, then under the negative log-likelihood loss of a Poisson distribution, the denoised data are the inferred Poisson rates.
Inputs: The main data structure is the dataset. It can be given any name, since the name is specified by one of the configuration parameters.
Format of dataset: the top-level dictionary simply maps a name (string) to a dictionary. The nested dictionary is the DATA DICTIONARY, which has the keys ‘train_data’ and ‘valid_data’, whose values are the corresponding training and validation data with shape ExTxD, where E is the number of examples, T the number of time steps, and D the number of dimensions in the data.
The data dictionary also has a few more keys:
‘train_ext_input’ and ‘valid_ext_input’: if there are known external inputs to the system being modelled, these take on dimensions ExTxI, where I is the number of input dimensions.
‘alignment_matrix_cxf’: if you are using data from multiple days, it is possible that the channels can be aligned across days. If so, each dataset will contain this matrix, which is used for both the input adapter and the output adapter for that dataset. These matrices, if provided, must be of size data_dim x factors, where data_dim is the number of neurons recorded on that day and factors is set in the configuration. A minimal sketch of assembling such a dataset follows.
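The snippet below sketches this structure with synthetic NumPy arrays; all names and sizes are illustrative placeholders, not recommendations.

import numpy as np

E, T, D = 100, 50, 30   # examples, time steps, data dimensions (e.g. neurons)
I = 2                   # external input dimensions
F = 20                  # number of factors; must match factors_dim in the config

data_dict = {
    # required: ExTxD training and validation data
    "train_data": np.random.poisson(1.0, size=(E, T, D)),
    "valid_data": np.random.poisson(1.0, size=(20, T, D)),
    # optional: known external inputs, ExTxI
    "train_ext_input": np.zeros((E, T, I)),
    "valid_ext_input": np.zeros((20, T, I)),
    # optional: per-day alignment matrix, data_dim x factors
    "alignment_matrix_cxf": np.random.randn(D, F),
}

# Top level: name (string) -> data dictionary. The name itself is arbitrary,
# because it is referenced through the configuration parameters.
datasets = {"dataset_day0": data_dict}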
See the Analysis Repo for more details. An example params file with most default values is provided.
config: (yaml) a yaml file containing the following parameters (an illustrative example config appears after the parameter list):
kind: Type of model to build {train, posterior_sample_and_average, posterior_push_mean, prior_sample, write_model_params}
output_dist: Type of output distribution {poisson, gaussian}
data_dir: None. Location of data is re-directed to AWS S3
lfads_save_dir: None. Save location is re-directed to AWS S3
checkpoint_pb_load_name: Name of checkpoint files {checkpoint_lve}
checkpoint_name: Name of checkpoint files (.ckpt) {lfads_vae}
output_filename_stem: Name of output file {""}
max_ckpt_to_keep: Max # of checkpoints to keep (rolling)
max_ckpt_to_keep_lve: Max # of checkpoints to keep for lowest validation error models
ps_nexamples_to_process: Number of examples to process for posterior sample and average (not number of samples to average over)
ext_input_dim: Dimension of external inputs
data_filename_stem: Name of datafile stem
device: Device on which to run the computation {"gpu:0"}
csv_log: Name of file to keep running log of fit likelihoods etc (.csv appended)
num_steps_for_gen_ic: Number of steps to train the generator initial condition
inject_ext_input_to_gen: Should observed inputs be input to model via encoders or injected directly into generator
cell_weight_scale: Input scaling for input weights in generator
ic_dim: Dimension of h0
factors_dim: Number of factors from generator
ic_enc_dim: Cell hidden size, encoder of h0
gen_dim: Cell hidden size, generator
gen_cell_input_weight_scale: Input scaling for input weights in generator
gen_cell_rec_weight_scale: Input scaling for rec weights in generator
ic_prior_var_min: Min variance in posterior h0 codes
ic_prior_var_scale: Variance of IC prior distribution
ic_prior_var_max: Max variance of IC prior distribution
ic_post_var_min: Min variance of IC posterior distribution
co_prior_var_scale: Variance of control input prior distribution
prior_ar_atau: Initial autocorrelation of AR(1) priors
prior_ar_nvar: Initial noise variance for AR(1) priors
do_train_prior_ar_atau: Is the value for atau an init, or the constant value?
do_train_prior_ar_nvar: Is the value for noise variance an init, or constant value?
do_causal_controller: Restrict the controller to create only causal inferred inputs?
controller_input_lag: Time lag on the encoding to the controller: t-lag for forward, t+lag for reverse
do_feed_factors_to_controller: Should factors [t-1] be input to controller at time t?
feedback_factors_or_rates: feedback the factors or the rates to the controller?
co_dim: Number of control net outputs (>0 builds that graph)
ci_enc_dim: cell hidden size, encoder of control inputs
con_dim: cell hidden size, controller
co_mean_corr_scale: Cost of correlation in the means of controller output
batch_size: batch size to use during training
learning_rate_init: learning rate initial value
learning_rate_decay_factor: learning rate decay, decay by this fraction every so often
learning_rate_stop: the lr is adaptively reduced; stop training once it falls below this value
learning_rate_n_to_compare: number of previous costs the current cost must be worse than before the lr is lowered (see the sketch after this list)
max_grad_norm: max norm of gradient before clipping
cell_clip_value: max value recurrent cell can take before being clipped
do_train_io_only: train only the input and output affine functions
do_train_encoder_only: train only the encoder weights
do_reset_learning_rate: reset the learning rate to initial value
keep_prob: dropout keep probability
temporal_spike_jitter_width: shuffle spikes around this window
l2_gen_scale: L2 regularization cost for the generator only
l2_con_scale: L2 regularization cost for the controller only
kl_ic_weight: strength of KL weight on initial conditions KL penalty
kl_co_weight: strength of KL weight on controller output KL penalty
kl_start_step: start increasing KL weight after this many steps
kl_increase_steps: number of steps over which the KL cost weight is increased, to avoid a local minimum (see the ramp sketch after this list)
l2_start_step: start increasing L2 weight after this many steps
l2_increase_steps: number of steps over which the L2 cost weight is increased, to avoid a local minimum
_clip_value: 80
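For concreteness, a config file using a handful of the parameters above might look like the following; every value shown is illustrative, not a recommendation.

kind: train
output_dist: poisson
data_dir: null            # location is re-directed to AWS S3
lfads_save_dir: null      # location is re-directed to AWS S3
checkpoint_pb_load_name: checkpoint_lve
checkpoint_name: lfads_vae
device: "gpu:0"
ic_dim: 64
factors_dim: 20
gen_dim: 200
co_dim: 1
batch_size: 128
learning_rate_init: 0.01
learning_rate_decay_factor: 0.95
learning_rate_stop: 1.0e-5
keep_prob: 0.95
kl_ic_weight: 1.0
kl_co_weight: 1.0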
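The learning_rate_* parameters describe an adaptive schedule: the rate decays by learning_rate_decay_factor whenever the current training cost is worse than the previous learning_rate_n_to_compare costs, and training stops once the rate falls below learning_rate_stop. The sketch below illustrates that rule in plain Python (it is not the repository's code; the function name and the random stand-in cost are hypothetical).

import random

def maybe_decay_lr(lr, recent_costs, current_cost,
                   decay_factor=0.95, n_to_compare=6):
    # Decay the learning rate when the current cost is worse than each of
    # the previous n_to_compare costs.
    window = recent_costs[-n_to_compare:]
    if len(window) == n_to_compare and all(current_cost > c for c in window):
        return lr * decay_factor
    return lr

lr, costs = 0.01, []                 # learning_rate_init
while lr >= 1e-5:                    # learning_rate_stop
    cost = random.random()           # stand-in for one training step's cost
    lr = maybe_decay_lr(lr, costs, cost)
    costs.append(cost)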
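Similarly, the kl_*/l2_* warm-up parameters ramp the KL and L2 cost weights linearly from zero to full strength so that training does not fall into a trivial local minimum early on. A hypothetical sketch of such a ramp (not the repository's implementation):

def ramped_weight(step, start_step, increase_steps, full_weight):
    # 0 before start_step, then a linear ramp to full_weight over increase_steps.
    progress = (step - start_step) / float(increase_steps)
    return full_weight * min(max(progress, 0.0), 1.0)

# e.g. with kl_start_step=0, kl_increase_steps=2000, kl_ic_weight=1.0:
assert ramped_weight(0, 0, 2000, 1.0) == 0.0
assert ramped_weight(1000, 0, 2000, 1.0) == 0.5
assert ramped_weight(5000, 0, 2000, 1.0) == 1.0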
Outputs:
LFADS components: the inferred factors, initial conditions, control inputs, and a denoised version of the dataset (e.g. the inferred Poisson rates).