scmiracle.model
- class scmiracle.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Decoder class for multi-modal data with shared and modality-specific decoding layers.
- Parameters:
dims_x – Dict[str, list] Output dimensions for each modality.
dims_h – Dict[str, list] Hidden dimensions for each modality.
dim_z – int Latent dimension size.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘relu’).
drop – float Dropout rate.
kwargs – Dict[str, Any] Additional modality-specific configurations.
- class scmiracle.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Discriminator class for multi-modal latent variables.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality.
dims_s – Dict[str, int] Number of classes (batches) for each modality.
kwargs – Dict[str, Any] Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.
- calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) → Tensor
Calculate cross-entropy loss for all modalities.
- Parameters:
predictions – Dict[str, torch.Tensor] Dictionary of predicted logits for each modality.
targets – Dict[str, torch.Tensor] Dictionary of ground truth labels for each modality.
- Returns:
Total normalized loss.
- Return type:
torch.Tensor
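As an illustration of this loss, the NumPy sketch below computes a softmax cross-entropy for each modality and sums the results. It is not scmiracle's implementation: total_dsc_loss is a hypothetical name, and reading "normalized" as averaging over cells is an assumption.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean softmax cross-entropy of integer labels, averaged over cells."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(labels.shape[0]), labels].mean())


def total_dsc_loss(predictions: dict, targets: dict) -> float:
    """Hypothetical total loss: sum of the per-modality mean cross-entropies."""
    return sum(cross_entropy(predictions[m], targets[m]) for m in predictions)


# Uniform logits over 3 batch classes give log(3) per modality.
predictions = {"rna": np.zeros((4, 3)), "adt": np.zeros((4, 3))}
targets = {"rna": np.array([0, 1, 2, 0]), "adt": np.array([2, 2, 1, 0])}
loss = total_dsc_loss(predictions, targets)
```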
- forward(latent_inputs: Dict[str, Tensor]) → Dict[str, Tensor]
Forward pass for the discriminator.
- Parameters:
latent_inputs – Dict[str, torch.Tensor] Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).
- Returns:
Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).
- Return type:
Dict[str, torch.Tensor]
- class scmiracle.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)
Bases: Module
Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality (e.g., {‘rna’: [1000], ‘adt’: [100]}).
dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding (e.g., {‘rna’: [256], ‘adt’: [256]}).
dim_z – int Latent dimension size (e.g., 32).
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
out_trans – str Output activation function (e.g., ‘mish’).
drop – float Dropout rate.
kwargs – Dict[str, Any] Additional modality-specific configurations.
Notes
By default, RNA and ADT data are log1p-transformed before encoding and exponentiated back after decoding. To skip this step, modify the configuration file; see the parameter ‘trsf_before_enc_’.
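The transform pair in this note can be illustrated with NumPy, assuming the inverse applied after decoding is expm1 (exp(y) - 1), which exactly undoes log1p. This only demonstrates the arithmetic; it does not show where scmiracle applies either step.

```python
import numpy as np

# Toy count matrix: 2 cells x 3 features.
counts = np.array([[0.0, 3.0, 10.0], [1.0, 0.0, 250.0]])

encoded = np.log1p(counts)   # log(1 + x), applied before encoding
decoded = np.expm1(encoded)  # exp(y) - 1, the assumed inverse applied after decoding
```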
- forward(data: Dict[str, Tensor], mask: Dict[str, Tensor]) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
Forward pass for the encoder.
- Parameters:
data – Dict[str, torch.Tensor] Input data for each modality.
mask – Dict[str, torch.Tensor] Masks for each modality.
- Returns:
- z_x_mu: Dict[str, torch.Tensor]
Means of the latent variables for each modality.
- z_x_logvar: Dict[str, torch.Tensor]
Log-variances of the latent variables for each modality.
- Return type:
Tuple
- class scmiracle.model.MIDAS
Bases: LightningModule
MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.
- net
VAE Variational Autoencoder for multi-modal data encoding and decoding.
- dsc
Discriminator Discriminator for distinguishing latent variables across batches.
- configs
Dict[str, Any] Model and training configurations dynamically set as attributes.
- automatic_optimization
bool Controls whether optimization is automatic or manually defined. Always True.
- static calc_consistency_loss(z_uni: Dict[str, Tensor]) → Tensor
Calculate the consistency loss for unified latent variables across modalities.
- Parameters:
z_uni – Dict[str, torch.Tensor] Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).
- Returns:
Consistency loss computed as the variance of the unified latent variables.
- Return type:
torch.Tensor
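A minimal NumPy sketch of a variance-based consistency loss, assuming every modality's latent has the same (batch_size x latent_dim) shape. The reduction used here (sum over latent dimensions, mean over cells) is an assumption, and consistency_loss is a hypothetical name, not the library method.

```python
import numpy as np


def consistency_loss(z_uni: dict) -> float:
    """Variance across modalities, summed over latent dims and averaged over cells."""
    # Shape: (n_modalities, batch_size, latent_dim).
    stacked = np.stack(list(z_uni.values()), axis=0)
    # Population variance across the modality axis for every cell and dimension.
    per_element_var = stacked.var(axis=0)
    return float(per_element_var.sum(axis=1).mean())


z_uni = {"rna": np.zeros((5, 3)), "adt": np.full((5, 3), 2.0)}
loss = consistency_loss(z_uni)
```

Identical latents across modalities give a loss of zero, which is the agreement the term rewards.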
- static calc_dsc_loss(pred: Dict[str, Tensor], true: Dict[str, Tensor]) → Tensor
Calculate the discriminator loss using cross-entropy.
- Parameters:
pred – Dict[str, torch.Tensor] Predicted logits for each modality.
true – Dict[str, torch.Tensor] Ground truth labels for each modality.
- Returns:
Computed discriminator loss.
- Return type:
torch.Tensor
- static calc_kld_loss(mu: Tensor, logvar: Tensor) → Tensor
Calculate the KLD loss for a single latent space.
- Parameters:
mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).
- Returns:
KLD loss for the latent space, normalized by batch size.
- Return type:
torch.Tensor
- static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor) → Tensor
Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.
- Parameters:
dim_c – int Dimension of the biological latent space.
dim_u – int Dimension of the technical latent space.
lam_kld_c – float Weight for KLD loss of the biological latent space.
lam_kld_u – float Weight for KLD loss of the technical latent space.
mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).
logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).
- Returns:
Weighted sum of KLD losses for the biological and technical latent spaces.
- Return type:
torch.Tensor
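The standard closed-form KL divergence from N(mu, exp(logvar)) to a standard normal is -0.5 * sum(1 + logvar - mu**2 - exp(logvar)). The NumPy sketch below applies it and then splits the latent into its biological and technical parts. It assumes the first dim_c columns hold c and the next dim_u columns hold u; the functions mirror the two methods above but are illustrative versions, not the library code.

```python
import numpy as np


def kld_loss(mu: np.ndarray, logvar: np.ndarray) -> float:
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over all entries, divided by batch size."""
    kld = -0.5 * (1.0 + logvar - mu**2 - np.exp(logvar))
    return float(kld.sum() / mu.shape[0])


def kld_z_loss(dim_c, dim_u, lam_kld_c, lam_kld_u, mu, logvar) -> float:
    """Weighted KLD of the c part (first dim_c columns) and u part (next dim_u columns)."""
    c_slice = slice(0, dim_c)
    u_slice = slice(dim_c, dim_c + dim_u)
    loss_c = kld_loss(mu[:, c_slice], logvar[:, c_slice])
    loss_u = kld_loss(mu[:, u_slice], logvar[:, u_slice])
    return lam_kld_c * loss_c + lam_kld_u * loss_u


mu = np.ones((2, 3))
logvar = np.zeros((2, 3))
total = kld_z_loss(dim_c=2, dim_u=1, lam_kld_c=1.0, lam_kld_u=10.0, mu=mu, logvar=logvar)
```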
- static calc_recon_loss(x: Dict[str, Tensor], s: Tensor, e: Dict[str, Tensor], x_r_pre: Dict[str, Tensor], s_r_pre: Dict[str, Tensor], dist: Dict[str, str], lam: Dict[str, float]) → Tuple[float, Dict[Tensor, Tensor]]
Calculate the reconstruction loss for input data and predicted outputs.
- Parameters:
x – Dict[str, torch.Tensor] Original input data for each modality (x^m).
s – torch.Tensor Ground truth batch labels.
e – Dict[str, torch.Tensor] Masks for each modality.
x_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for each modality (x_r^m).
s_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for batch labels.
dist – Dict[str, str] Dictionary specifying the distribution type for each modality’s decoder.
lam – Dict[str, float] Dictionary containing reconstruction loss weights for each modality and for s.
- Returns:
- total_loss: torch.Tensor
Total reconstruction loss, normalized by batch size.
- losses: Dict[str, torch.Tensor]
Dictionary containing reconstruction losses for each modality and for batch labels.
- Return type:
Tuple
- classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto', viz_umap_tb=False, batch_names=None) → MIDAS
Configure the data and model parameters for training.
- Parameters:
configs – dict Configurations of the model.
datalist – List[Dataset] List of datasets to be used for training.
dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.
dims_s – Dict[str, int] Number of classes (batches) for each modality.
s_joint – List[Dict[str, int]] Modality ID for each batch.
combs – List[List[str]] Combinations of modalities.
batch_size – int, optional Size of each training batch, by default 256.
n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.
save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.
viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.
batch_names – list, optional List of batch names, by default None.
- Returns:
The configured MIDAS instance.
- Return type:
class ‘MIDAS’
- classmethod configure_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, **kwargs: Dict[str, Any]) → MIDAS
Configure data from a directory and apply optional transformations.
- Parameters:
configs – Dict[str, Any] Configurations of the model.
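dir_path – str Path to the directory containing data files.
format – str, optional File format of the data (‘mtx’, ‘csv’, or ‘vec’), by default ‘mtx’.
transform – Dict[str, str], optional A dictionary specifying transformations to apply to specific modalities, e.g. {‘atac’: ‘binarize’}. By default None, which uses the default transformation settings.
sampler_type – str, optional Type of sampler to use, by default ‘auto’. Set to ‘ddp’ to use a distributed sampler.
viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.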
kwargs – Dict[str, Any] Additional parameters passed to configure_data().
- Returns:
Returns the configured class instance.
- Return type:
class ‘MIDAS’
Examples
>>> from scmiracle.model import MIDAS
>>> from scmiracle.config import load_config
>>> configs = load_config()
>>> dir_path = 'XXX'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform=transform)
- classmethod configure_new_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, scale=None, viz_umap_tb: bool = False, **kwargs: Dict[str, Any]) → MIDAS
- configure_optimizers() → List[Optimizer]
Configure optimizers for the MIDAS model.
- Returns:
List of optimizers for the network and discriminator.
- Return type:
List[torch.optim.Optimizer]
- classmethod get_datasets_from_dir(data: List[Dict[str, str]], mask: List[Dict[str, str]], transform: Dict[str, str] = None, format: str = 'mtx')
Create datasets from lists of data and mask file paths.
- Parameters:
data – List[Dict[str, str]] List of data dictionaries, where keys are modalities and values are file paths.
mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths.
transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.
format – str File type of the input data (‘vec’, ‘mtx’, or ‘csv’), by default ‘mtx’.
- Returns:
- datasets: List[MultiModalDataset]
List of initialized MultiModalDataset objects.
- dims_s: Dict[str, int]
Dimensions for batch correction for each modality.
- s_joint: List[Dict[str, int]]
Modality indices for each batch.
- combs: List[List[str]]
List of modality combinations for each batch.
- Return type:
Tuple
- get_emb_umap(pred_dir: str = None, save_dir: str = None, use_mtx: bool = True, drop_c_umap: bool = False, drop_u_umap: bool = False, **kwargs) → Tuple[List[AnnData], List[Figure]]
Generate UMAP embeddings for biological (c) and technical (u) latent variables.
- Parameters:
pred_dir – str Directory containing predicted data.
save_dir – str, optional Directory to save UMAP plots, by default ‘./’.
use_mtx – bool, optional Whether to load embeddings stored in MTX format, by default True.
drop_c_umap – bool, optional Whether to drop the biological embedding (c) from UMAP, by default False.
drop_u_umap – bool, optional Whether to drop the technical embedding (u) from UMAP, by default False.
kwargs – Dict[str, Any] Additional keyword arguments passed to sc.pl.umap().
- Returns:
- all_adata: List[AnnData]
List of AnnData objects containing UMAP embeddings for each modality.
- all_figures: List[plt.Figure]
List of UMAP figures for biological and technical embeddings.
- Return type:
Tuple
- static get_info_from_dir(dir_path: str, format: str)
Extract data paths, mask paths, and feature dimensions from a dataset directory.
- Parameters:
dir_path – str Path to the directory containing data and mask files.
format – str File format of the data: ‘mtx’, ‘csv’, or ‘vec’.
- Returns:
- data: List[Dict[str, str]]
List of dictionaries where keys are modalities and values are file paths.
- mask: List[Dict[str, str]]
List of dictionaries where keys are modalities and values are mask file paths.
- dims_x: Dict[str, list]
Dictionary containing feature dimensions for each modality.
- Return type:
Tuple
Notes
The directory should be organized as:
dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv  # the first sample
        vec/mod1/0001.csv  # the second sample
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv
        vec/mod1/0001.csv
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
or, for the matrix formats (‘mtx’ or ‘csv’):
dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
- load_checkpoint(checkpoint_path: str, start_epoch: int = 0, **kwargs)
Load model and optimizer states from a checkpoint file.
- Parameters:
checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.
start_epoch – int, optional Number of epochs the model has already been trained for, by default 0.
kwargs – Dict[str, Any] Additional keyword arguments passed to torch.load().
- Raises:
AssertionError – If the provided checkpoint path does not exist.
- log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])
Log losses for monitoring and debugging during training.
- Parameters:
recon_loss – torch.Tensor Reconstruction loss.
kld_loss – torch.Tensor KLD loss.
consistency_loss – torch.Tensor Consistency loss.
loss_net – torch.Tensor Total VAE loss.
loss_dsc – torch.Tensor Discriminator loss.
recon_dict – Dict[str, torch.Tensor] Per-modality reconstruction losses.
- on_train_epoch_end()
Save a model checkpoint at the end of each training epoch with a meaningful filename.
- pack_data(des_dir, subsample_num: list = None, subsample='random', format='mtx', use_transform: bool = False, return_idx=True)
Pack the data in datalist into a directory structure under des_dir, with an option for proportional subsampling.
- predict(return_format='array', save_dir: str = None, save_format='h5ad', joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False) → Dict[str, Dict[str, Tensor]] | None
Run model inference to generate latent embeddings, reconstructed data, or translated modalities.
The method iterates through the dataset, computes representations using the trained network, and optionally saves the results to disk or returns them in memory.
- Parameters:
return_format (str, default='array') – The format of the returned data:
- ‘array’: Returns a nested dictionary of NumPy arrays.
- ‘anndata’: Returns a dictionary of AnnData objects (requires scanpy).
save_dir (str, optional, default=None) – Directory path to save the prediction results. If None, results are not saved to disk.
save_format (str, default='h5ad') – File format for saving results if save_dir is provided:
- ‘h5ad’: Saves as H5AD files (one per batch).
- ‘mtx’: Saves as Matrix Market files inside batch-specific folders.
joint_latent (bool, default=True) – If True, computes the joint latent representation (z) conditioned on all observed modalities.
mod_latent (bool, default=False) – If True, computes latent representations conditioned on each individual modality. (Automatically set to True if translate is True).
impute (bool, default=False) – If True, generates imputed data (x_impt) from the joint latent space.
batch_correct (bool, default=False) – If True, calculates batch centroids and performs batch-effect correction on the reconstructed data (x_bc). Note: This involves a second pass through the data.
translate (bool, default=False) – If True, performs cross-modality translation (e.g., generating Modality B from Modality A).
input (bool, default=False) – If True, includes the original raw input data and masks in the output.
- Returns:
output – A dictionary whose keys are batch names.
If return_format='array':
{
    'batch_name': {
        'z_c': {'joint': np.ndarray, 'modality_name': ...},  # Content latent
        'z_u': {'joint': np.ndarray, ...},                   # Batch latent
        'x_impt': {'modality_name': np.ndarray},             # Imputed data
        'x_bc': {'modality_name': np.ndarray},               # Batch-corrected data
        ...
    }
}
If return_format='anndata':
{
    'batch_name': ann_data_object
}
Latent variables are stored in adata.obsm and masks in adata.uns.
- Return type:
Dict
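With return_format='array', the nested layout makes it straightforward to pool embeddings across batches, for example before clustering. The helper below is hypothetical (not part of scmiracle) and uses toy arrays in place of real predict() output; it concatenates the joint content embeddings (z_c['joint']) and records each cell's batch name.

```python
import numpy as np


def stack_joint_embeddings(output: dict, key: str = "z_c"):
    """Concatenate output[batch][key]['joint'] over batches, keeping batch labels."""
    blocks, labels = [], []
    for batch_name, result in output.items():
        emb = result[key]["joint"]
        blocks.append(emb)
        labels.extend([batch_name] * emb.shape[0])
    return np.concatenate(blocks, axis=0), labels


# Toy stand-in for predict(return_format='array', joint_latent=True).
output = {
    "batch_0": {"z_c": {"joint": np.zeros((3, 4))}},
    "batch_1": {"z_c": {"joint": np.ones((2, 4))}},
}
emb, batch_labels = stack_joint_embeddings(output)
```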
- static print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str])
Print summary of mask density and dataset information.
- Parameters:
mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths.
datalist – List[Dataset] List of datasets.
batch_names – List[str] List of batch names.
- save_checkpoint(checkpoint_path: str)
Save the current model and optimizer states to a checkpoint file.
- Parameters:
checkpoint_path – str Path to save the checkpoint file.
- Raises:
ValueError – If checkpoint_path is an invalid or empty string.
- train_dataloader() → DataLoader
Create a DataLoader for training, using the appropriate sampler.
- Returns:
Configured DataLoader instance for training.
- Return type:
DataLoader
- train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor], scale: float = 1.0) → None
Train the discriminator with modality-specific latent representations.
- Parameters:
c_all – Dict[str, torch.Tensor] Dictionary of latent representations for each modality.
targets – Dict[str, torch.Tensor] Ground truth batch labels for each modality.
- training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → Tensor
Executes a single training step for MIDAS.
- Parameters:
batch – Dict[str, Dict[str, torch.Tensor]] Input batch containing modality data, batch indices, and masks.
batch_idx – int Index of the current training batch.
- Returns:
Total VAE loss for the current batch.
- Return type:
torch.Tensor
- update(model2: Module)
Update the weights of this model from model2 by left-aligning the weights.
- Parameters:
model2 (nn.Module) – Source model from which weights are taken. The target model being updated is the instance on which update() is called.
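One plausible reading of "left-aligning" (an assumption; the description does not spell it out) is that, for each parameter present in both models, the overlapping leading block of the source tensor is copied into the target, for example when a layer has grown. A NumPy sketch of that per-parameter copy, using the hypothetical name left_align_copy:

```python
import numpy as np


def left_align_copy(target: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Return a copy of target whose leading overlap with source is taken from source."""
    if target.ndim != source.ndim:
        raise ValueError("target and source must have the same number of dimensions")
    out = target.copy()
    # Along every axis, the overlap spans indices 0 .. min(size_target, size_source) - 1.
    overlap = tuple(slice(0, min(t, s)) for t, s in zip(target.shape, source.shape))
    out[overlap] = source[overlap]
    return out


new_weight = np.zeros((3, 4))  # e.g. a layer that gained one output row
old_weight = np.ones((2, 4))
updated = left_align_copy(new_weight, old_weight)
```

The rows without a counterpart in the source keep their existing values, and the input arrays are left unchanged.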
- static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip: int = -1)
Update model parameters using backpropagation.
- Parameters:
loss – torch.Tensor Computed loss for backpropagation.
model – torch.nn.Module Model to update.
optimizer – torch.optim.Optimizer Optimizer for parameter updates.
grad_clip – int, optional Gradient clipping threshold, by default -1 (a negative value disables clipping).
- class scmiracle.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)
Bases: Module
Decoder for reconstructing batch ID.
- Parameters:
n_batches – int Number of distinct batches.
dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.
dim_u – int Latent dimension size for the input (e.g., 2).
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmiracle.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)
Bases: Module
Encoder for batch ID latent variables.
- Parameters:
n_batches – int Number of distinct batches.
dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.
dim_z – int Dimension of the output latent variables.
norm – str Normalization type (e.g., ‘ln’ for LayerNorm).
drop – float Dropout rate.
- class scmiracle.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)
Bases: Module
Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling from distributions.
- Parameters:
dims_x – Dict[str, list] Input dimensions for each modality, e.g., {‘rna’: [1000], ‘adt’: [100], ‘atac’: [10, 10, 10]}.
dims_s – Dict[str, int] Number of classes (batches) for each modality.
kwargs – Dict[str, Any] Additional configurations for encoders, decoders, and other modules.
- encode_batch(s: Tensor) → Tuple[list, list] | None
Encode batch IDs into latent variables.
- Parameters:
s – torch.Tensor Batch IDs.
- Returns:
- z_s_mu: List[torch.Tensor]
Means of the batch-ID latent variables.
- z_s_logvar: List[torch.Tensor]
Log-variances of the batch-ID latent variables.
- Return type:
Optional[Tuple[list, list]]
- forward(data: Dict[str, Tensor]) → Tuple[Dict[str, Tensor], Tensor | None, Tensor, Tensor, Tensor, Tensor, Tensor, Dict[str, Tensor], Dict[str, Tensor]]
Forward pass for the VAE.
- Parameters:
data – Dict[str, torch.Tensor] Input data dictionary containing:
- ‘x’: Dict[str, torch.Tensor], modality-specific input data.
- ‘e’: Dict[str, torch.Tensor], modality-specific masks.
- ‘s’ (optional): torch.Tensor, batch IDs.
- Returns:
- x_r_pre: Dict[str, torch.Tensor]
Reconstructed modality-specific data.
- s_r_pre: Optional[torch.Tensor]
Reconstructed batch indices if ‘s’ is provided, otherwise None.
- z_mu: torch.Tensor
Mean of the combined latent variables.
- z_logvar: torch.Tensor
Log-variance of the combined latent variables.
- z: torch.Tensor
Sampled latent variables.
- c: torch.Tensor
Biological information variables.
- u: torch.Tensor
Technical noise variables.
- z_uni: Dict[str, torch.Tensor]
Unified latent variables for each modality.
- c_all: Dict[str, torch.Tensor]
Modality-specific biological information variables.
- Return type:
Tuple
- gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) → Dict[str, Tensor]
Convert reconstructed decoder outputs into data-space values for each modality.
- Parameters:
x_r_pre – Dict[str, torch.Tensor] Dictionary of reconstructed data tensors for each modality.
sampling – bool, optional Whether to sample the output (default: True).
- Returns:
Generated real data for each modality.
- Return type:
Dict[str, torch.Tensor]
- generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
Generate unified latent variables and modality-specific representations.
- Parameters:
z_x_mu – Dict[str, torch.Tensor] Means of modality-specific latent variables.
z_x_logvar – Dict[str, torch.Tensor] Log-variances of modality-specific latent variables.
z_s_mu – List[torch.Tensor] Mean of the batch-ID latent variables.
z_s_logvar – List[torch.Tensor] Log-variance of the batch-ID latent variables.
c – torch.Tensor Biological information.
- Returns:
- z_uni: Dict[str, torch.Tensor]
Collection of latent variables for the unimodal inputs.
- c_all: Dict[str, torch.Tensor]
Collection of biological information for the unimodal and joint inputs.
- Return type:
Tuple
- get_dim_h() → Dict[str, List[int]]
Compute hidden dimensions for each modality.
- Returns:
A dictionary containing the hidden dimensions for each modality.
- Return type:
Dict[str, List[int]]
- static poe(mus: List[Tensor], logvars: List[Tensor]) → Tuple[Tensor, Tensor]
Product of Experts (PoE) for combining Gaussian distributions.
- Parameters:
mus – list of torch.Tensor List of mean tensors for each Gaussian.
logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.
- Returns:
- combined_mean: torch.Tensor
Mean of the combined Gaussian distribution.
- combined_logvar: torch.Tensor
Log-variance of the combined Gaussian distribution.
- Return type:
Tuple
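For independent Gaussian experts, the product is Gaussian with precision equal to the sum of the experts' precisions and mean equal to the precision-weighted average of their means. The NumPy sketch below implements that standard rule; it does not add a standard-normal prior expert, and whether scmiracle's poe() does is not stated here.

```python
import numpy as np


def poe(mus, logvars):
    """Combine Gaussian experts N(mu_i, exp(logvar_i)) by precision weighting."""
    precisions = [np.exp(-lv) for lv in logvars]  # 1 / var_i for each expert
    total_precision = np.sum(precisions, axis=0)
    combined_var = 1.0 / total_precision
    weighted = np.sum([m * p for m, p in zip(mus, precisions)], axis=0)
    combined_mean = combined_var * weighted
    return combined_mean, np.log(combined_var)


mus = [np.array([[0.0]]), np.array([[2.0]])]
logvars = [np.array([[0.0]]), np.array([[0.0]])]
mean, logvar = poe(mus, logvars)
```

Two unit-variance experts centered at 0 and 2 combine to mean 1 and variance 0.5, so agreement between experts tightens the combined distribution.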
- static sample(name: str, data: Tensor, sampling: bool) → Tensor
Apply the sampling function associated with the given distribution name.
- Parameters:
name – str Name of the distribution.
data – torch.Tensor Input data tensor.
sampling – bool Whether to apply sampling.
- Returns:
- torch.Tensor
Sampled or original data tensor.
- static sample_gaussian(mu: Tensor, logvar: Tensor) → Tensor
Sample from a Gaussian distribution using the reparameterization trick.
- Parameters:
mu – torch.Tensor Mean of the Gaussian distribution.
logvar – torch.Tensor Log-variance of the Gaussian distribution.
- Returns:
- torch.Tensor
Sampled tensor.
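The reparameterization trick writes a sample as z = mu + exp(0.5 * logvar) * eps with eps drawn from N(0, I), so the randomness sits in eps and gradients can flow through mu and logvar. A NumPy sketch of the computation (in PyTorch the noise would typically come from torch.randn_like); the function here is illustrative, not the library method:

```python
import numpy as np


def sample_gaussian(mu, logvar, rng=None):
    """Draw z = mu + std * eps with std = exp(0.5 * logvar) and eps ~ N(0, I)."""
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(np.shape(mu))
    return mu + std * eps


rng = np.random.default_rng(0)
mu = np.full(200_000, 3.0)
logvar = np.full(200_000, np.log(4.0))  # variance 4, standard deviation 2
z = sample_gaussian(mu, logvar, rng)
```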
- sample_latent(z_mu: Tensor, z_logvar: Tensor) → Tensor
Sample latent variables from a Gaussian distribution.
- Parameters:
z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).
z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).
- Returns:
Sampled latent variables of shape (batch_size, latent_dim).
- Return type:
torch.Tensor