scmiracle.model

class scmiracle.model.Decoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Decoder class for multi-modal data with shared and modality-specific decoding layers.

Parameters:
  • dims_x – Dict[str, list] Output dimensions for each modality.

  • dims_h – Dict[str, list] Hidden dimensions for each modality.

  • dim_z – int Latent dimension size.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘relu’).

  • drop – float Dropout rate.

  • kwargs – Dict[str, Any] Additional modality-specific configurations.

forward(latent_data: Tensor) -> Dict[str, Tensor]

Forward pass for the decoder.

Parameters:

latent_data – torch.Tensor Latent variable input tensor of shape (batch_size, dim_z).

Returns:

Decoded outputs for each modality.

Return type:

Dict[str, torch.Tensor]

class scmiracle.model.Discriminator(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Discriminator class for multi-modal latent variables.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality.

  • dims_s – Dict[str, int] Dimensions of the classes for each modality.

  • kwargs – Dict[str, Any] Additional configurations, such as hidden layer sizes, dropout rate, and normalization type.

calculate_loss(predictions: Dict[str, Tensor], targets: Dict[str, Tensor]) -> Tensor

Calculate cross-entropy loss for all modalities.

Parameters:
  • predictions – Dict[str, torch.Tensor] Dictionary of predicted logits for each modality.

  • targets – Dict[str, torch.Tensor] Dictionary of ground truth labels for each modality.

Returns:

Total normalized loss.

Return type:

torch.Tensor

forward(latent_inputs: Dict[str, Tensor]) -> Dict[str, Tensor]

Forward pass for the discriminator.

Parameters:

latent_inputs – Dict[str, torch.Tensor] Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c).

Returns:

Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]).

Return type:

Dict[str, torch.Tensor]

class scmiracle.model.Encoder(dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs)

Bases: Module

Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality (e.g., {‘rna’: [1000], ‘adt’: [100]}).

  • dims_h – Dict[str, list] Hidden dimensions for each modality after pre-encoding (e.g., {‘rna’: 256, ‘adt’: 256}).

  • dim_z – int Latent dimension size (e.g., 32).

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • out_trans – str Output activation function (e.g., ‘mish’).

  • drop – float Dropout rate.

  • kwargs – Dict[str, Any] Additional modality-specific configurations.

Notes

By default, RNA and ADT data are log1p-transformed in the encoder and will be exponentiated after decoding. To skip this step, modify the configuration file. See parameter ‘trsf_before_enc_’.

forward(data: Dict[str, Tensor], mask: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]

Forward pass for the encoder.

Parameters:
  • data – Dict[str, torch.Tensor] Input data for each modality.

  • mask – Dict[str, torch.Tensor] Masks for each modality.

Returns:

  • z_x_mu – Dict[str, torch.Tensor]

    Mean values of the latent space for each modality.

  • z_x_logvar – Dict[str, torch.Tensor]

    Log-variance values of the latent space for each modality.

Return type:

Tuple

class scmiracle.model.MIDAS

Bases: LightningModule

MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis.

net

VAE Variational Autoencoder for multi-modal data encoding and decoding.

dsc

Discriminator Discriminator for distinguishing latent variables across batches.

configs

Dict[str, Any] Model and training configurations dynamically set as attributes.

automatic_optimization

bool Controls whether optimization is automatic or manually defined. Always True.

static calc_consistency_loss(z_uni: Dict[str, Tensor]) -> Tensor

Calculate the consistency loss for unified latent variables across modalities.

Parameters:

z_uni – Dict[str, torch.Tensor] Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim).

Returns:

Consistency loss computed as the variance of the unified latent variables.

Return type:

torch.Tensor
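
As a rough illustration of "variance of the unified latent variables", the sketch below takes, for each cell and latent dimension, the population variance across modalities, sums over dimensions, and averages over cells. The reduction, the choice of population variance, the function name `consistency_loss`, and the use of plain nested lists instead of tensors are all assumptions made to keep the example self-contained; the library's actual computation may differ.

```python
from statistics import pvariance
from typing import Dict, List

Matrix = List[List[float]]  # rows are cells, columns are latent dimensions


def consistency_loss(z_uni: Dict[str, Matrix]) -> float:
    """Variance across modalities, summed over dims, averaged over cells.

    Hypothetical reduction; not taken from the library source.
    """
    mats = list(z_uni.values())
    n_cells, n_dims = len(mats[0]), len(mats[0][0])
    total = 0.0
    for i in range(n_cells):
        for j in range(n_dims):
            # Spread of the same latent entry across all modalities.
            total += pvariance([m[i][j] for m in mats])
    return total / n_cells


z_uni = {
    "rna": [[0.0, 1.0], [2.0, 2.0]],
    "adt": [[2.0, 1.0], [2.0, 4.0]],
}
loss = consistency_loss(z_uni)
```

The loss is zero exactly when every modality yields the same latent values for every cell, which is the behaviour a consistency term is meant to encourage.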

static calc_dsc_loss(pred: Dict[str, Tensor], true: Dict[str, Tensor]) -> Tensor

Calculate the discriminator loss using cross-entropy.

Parameters:
  • pred – Dict[str, torch.Tensor] Predicted logits for each modality.

  • true – Dict[str, torch.Tensor] Ground truth labels for each modality.

Returns:

Computed discriminator loss.

Return type:

torch.Tensor
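
A minimal sketch of a per-modality cross-entropy on logits, written with plain lists so it runs without PyTorch. It assumes the per-modality losses are averaged over the batch and then summed across modalities; the names `cross_entropy` and `dsc_loss` and that reduction are illustrative assumptions, not the library's confirmed behaviour.

```python
import math
from typing import Dict, List


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood of the true class under softmax(logits)."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[label]
    return total / len(labels)


def dsc_loss(pred: Dict[str, List[List[float]]], true: Dict[str, List[int]]) -> float:
    """Sum of per-modality cross-entropies (assumed reduction)."""
    return sum(cross_entropy(pred[mod], true[mod]) for mod in pred)


pred = {"rna": [[0.0, 0.0], [1.0, 1.0]], "adt": [[3.0, 3.0], [0.0, 0.0]]}
true = {"rna": [0, 1], "adt": [1, 0]}
loss = dsc_loss(pred, true)
```

With uniform logits over two classes, each modality contributes log 2, which is the loss of a discriminator that cannot tell batches apart.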

static calc_kld_loss(mu: Tensor, logvar: Tensor) -> Tensor

Calculate the KLD loss for a single latent space.

Parameters:
  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim).

Returns:

KLD loss for the latent space, normalized by batch size.

Return type:

torch.Tensor
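
For orientation, the sketch below evaluates the closed-form KL divergence between a diagonal Gaussian N(mu, exp(logvar)) and a standard normal prior, summed over all entries and divided by the batch size. The standard normal prior and the function name `kld_loss` are assumptions (the text above only states the normalization), and plain lists stand in for tensors.

```python
import math
from typing import List

Matrix = List[List[float]]  # shape: batch_size x latent_dim


def kld_loss(mu: Matrix, logvar: Matrix) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed and divided by batch size."""
    batch_size = len(mu)
    total = 0.0
    for mu_row, lv_row in zip(mu, logvar):
        for m, lv in zip(mu_row, lv_row):
            # Closed form per dimension: -0.5 * (1 + logvar - mu^2 - var)
            total += -0.5 * (1.0 + lv - m * m - math.exp(lv))
    return total / batch_size


zero = kld_loss([[0.0, 0.0]], [[0.0, 0.0]])     # posterior equals the prior
shifted = kld_loss([[1.0, 0.0]], [[0.0, 0.0]])  # mean shifted by one unit
```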

static calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: Tensor, logvar: Tensor) -> Tensor

Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z.

Parameters:
  • dim_c – int Dimension of the biological latent space.

  • dim_u – int Dimension of the technical latent space.

  • lam_kld_c – float Weight for KLD loss of the biological latent space.

  • lam_kld_u – float Weight for KLD loss of the technical latent space.

  • mu – torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)).

  • logvar – torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)).

Returns:

Weighted sum of KLD losses for the biological and technical latent spaces.

Return type:

torch.Tensor
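
The weighted sum can be sketched as follows: split the latent columns into a biological block and a technical block, compute a KLD for each, and combine them with the two weights. The column order (the first dim_c columns are biological, the next dim_u are technical), the standard normal prior, and the helper names are assumptions made for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]  # shape: batch_size x (dim_c + dim_u)


def kld(mu: Matrix, logvar: Matrix) -> float:
    """KL to a standard normal prior, summed and divided by batch size."""
    total = 0.0
    for mu_row, lv_row in zip(mu, logvar):
        for m, lv in zip(mu_row, lv_row):
            total += -0.5 * (1.0 + lv - m * m - math.exp(lv))
    return total / len(mu)


def kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float,
               mu: Matrix, logvar: Matrix) -> float:
    # Assumed layout: columns [0, dim_c) hold c, columns [dim_c, dim_c + dim_u) hold u.
    mu_c = [row[:dim_c] for row in mu]
    mu_u = [row[dim_c:dim_c + dim_u] for row in mu]
    lv_c = [row[:dim_c] for row in logvar]
    lv_u = [row[dim_c:dim_c + dim_u] for row in logvar]
    return lam_kld_c * kld(mu_c, lv_c) + lam_kld_u * kld(mu_u, lv_u)


loss = kld_z_loss(2, 1, 1.0, 5.0, [[1.0, 0.0, 2.0]], [[0.0, 0.0, 0.0]])
```

Separate weights let the technical block be regularized more or less strongly than the biological block.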

static calc_recon_loss(x: Dict[str, Tensor], s: Tensor, e: Dict[str, Tensor], x_r_pre: Dict[str, Tensor], s_r_pre: Dict[str, Tensor], dist: Dict[str, str], lam: Dict[str, float]) -> Tuple[float, Dict[Tensor, Tensor]]

Calculate the reconstruction loss for input data and predicted outputs.

Parameters:
  • x – Dict[str, torch.Tensor] Original input data for each modality (x^m).

  • s – torch.Tensor Ground truth batch labels.

  • e – Dict[str, torch.Tensor] Mask.

  • x_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for each modality (x_r^m).

  • s_r_pre – Dict[str, torch.Tensor] Reconstructed predictions for batch labels.

  • dist – Dict[str, str] Dictionary specifying the distribution type for each modality’s decoder.

  • lam – Dict[str, float] Dictionary containing reconstruction loss weights for each modality and for s.

Returns:

  • total_loss – torch.Tensor

    Total reconstruction loss, normalized by batch size.

  • losses – Dict[str, torch.Tensor]

    Dictionary containing reconstruction losses for each modality and for batch labels.

Return type:

Tuple
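
A simplified sketch of a masked, weighted reconstruction loss follows. It covers only the modality terms (the batch-label term for s is omitted) and makes several assumptions that are not confirmed by the text above: the distribution names 'POISSON' and 'BERNOULLI' are hypothetical, x_r_pre is treated as a Poisson rate or a Bernoulli probability, and the mask has the same shape as the data.

```python
import math
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def poisson_nll(x: float, rate: float) -> float:
    """Poisson negative log-likelihood without the constant log(x!) term."""
    return rate - x * math.log(rate)


def bernoulli_nll(x: float, p: float) -> float:
    """Binary cross-entropy for a single entry."""
    return -(x * math.log(p) + (1.0 - x) * math.log(1.0 - p))


# Hypothetical distribution names; the real configuration keys may differ.
LOSSES = {"POISSON": poisson_nll, "BERNOULLI": bernoulli_nll}


def recon_loss(x: Dict[str, Matrix], e: Dict[str, Matrix],
               x_r_pre: Dict[str, Matrix], dist: Dict[str, str],
               lam: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    losses = {}
    for mod in x:
        fn = LOSSES[dist[mod]]
        total = 0.0
        for x_row, e_row, r_row in zip(x[mod], e[mod], x_r_pre[mod]):
            for xv, ev, rv in zip(x_row, e_row, r_row):
                if ev:  # masked-out features contribute nothing
                    total += fn(xv, rv)
        # Weight per modality and normalize by batch size.
        losses[mod] = lam[mod] * total / len(x[mod])
    return sum(losses.values()), losses


total, per_mod = recon_loss(
    x={"rna": [[1.0, 2.0]], "atac": [[1.0, 0.0]]},
    e={"rna": [[1, 0]], "atac": [[1, 1]]},
    x_r_pre={"rna": [[1.0, 1.0]], "atac": [[0.5, 0.5]]},
    dist={"rna": "POISSON", "atac": "BERNOULLI"},
    lam={"rna": 2.0, "atac": 1.0},
)
```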

classmethod configure_data(configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int], s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save: int = 500, save_model_path: str = './saved_models/', sampler_type: str = 'auto', viz_umap_tb=False, batch_names=None) -> MIDAS

Configure the data and model parameters for training.

Parameters:
  • configs – dict, Configurations of the model.

  • datalist – List[Dataset] List of datasets to be used for training.

  • dims_x – Dict[str, list] Dictionary specifying the dimensions of input features for each modality.

  • dims_s – Dict[str, int] Dimensions of the classes for each modality.

  • s_joint – List[Dict[str, int]] Modality ID for each batch.

  • combs – List[List[str]] Combinations of modalities.

  • batch_size – int, optional Size of each training batch, by default 256.

  • n_save – int, optional Interval (in epochs) for saving model checkpoints, by default 500.

  • save_model_path – str, optional Directory path for saving model checkpoints, by default ‘./saved_models/’.

  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. For ‘ddp’, use distributed sampler.

  • viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.

  • batch_names – list, optional List of batch names, by default None.

Returns:

Returns MIDAS instance.

Return type:

MIDAS

classmethod configure_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, **kwargs: Dict[str, Any]) -> MIDAS

Configure data from a directory and apply optional transformations.

Parameters:
  • configs – Dict[str, Any] Configurations of the model.

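  • dir_path – str Path to the directory containing data files.

  • format – str, optional File format of the input data, by default ‘mtx’. Supported values are ‘mtx’, ‘csv’, and ‘vec’.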

  • transform – Dict[str, str], optional A dictionary specifying transformations to apply to specific modalities. Example: {‘atac’: ‘binarize’} Default is None, which uses the default transformation settings.

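  • sampler_type – str, optional Type of sampler to use, by default ‘auto’. For ‘ddp’, use distributed sampler.

  • viz_umap_tb – bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False.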

  • kwargs – Dict[str, Any] Additional parameters passed to configure_data().

Returns:

Returns the configured class instance.

Return type:

MIDAS

Examples

>>> from scmidas.model import MIDAS
>>> from scmidas.config import load_config
>>> configs = load_config()
>>> dir_path = 'XXX'
>>> transform = {'atac': 'binarize'}
>>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform=transform)

classmethod configure_new_data_from_dir(configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, scale=None, viz_umap_tb: bool = False, **kwargs: Dict[str, Any]) -> MIDAS

configure_optimizers() -> List[Optimizer]

Configure optimizers for the MIDAS model.

Returns:

List of optimizers for the network and discriminator.

Return type:

List[torch.optim.Optimizer]

classmethod get_datasets_from_dir(data: List[Dict[str, str]], mask: List[Dict[str, str]], transform: Dict[str, str] = None, format: str = 'mtx')

Build datasets from lists of data and mask file paths.

Parameters:
  • data – List[Dict[str, str]] List of data dictionaries, where keys are modalities and values are file paths.

  • mask – List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths.

  • transform – Optional[Dict[str, str]] Transformations to apply to specific modalities.

  • format – str File type of the input data, by default ‘mtx’. Supported values are ‘vec’, ‘mtx’, and ‘csv’.

Returns:

  • datasets – List[MultiModalDataset]

    List of initialized MultiModalDataset objects.

  • dims_s – Dict[str, int]

    Dimensions for batch correction for each modality.

  • s_joint – List[Dict[str, int]]

    Modality indices for each batch.

  • combs – List[List[str]]

    List of modality combinations for each batch.

Return type:

Tuple

get_emb_umap(pred_dir: str = None, save_dir: str = None, use_mtx: bool = True, drop_c_umap: bool = False, drop_u_umap: bool = False, **kwargs) -> Tuple[List[AnnData], List[Figure]]

Generate UMAP embeddings for biological (c) and technical (u) latent variables.

Parameters:
  • pred_dir – str Directory containing predicted data.

  • save_dir – str, optional Directory to save UMAP plots, by default ‘./’.

  • use_mtx – bool, optional Whether to load embeddings of mtx format, by default True.

  • drop_c_umap – bool, optional Whether to drop the biological embedding (c) from UMAP, by default False.

  • drop_u_umap – bool, optional Whether to drop the technical embedding (u) from UMAP, by default False.

  • kwargs – Dict[str, Any] Additional configurations for sc.pl.umap().

Returns:

  • all_adata – List[AnnData]

    List of AnnData objects containing UMAP embeddings for each modality.

  • all_figures – List[plt.Figure]

    List of UMAP figures for biological and technical embeddings.

Return type:

Tuple

static get_info_from_dir(dir_path: str, format: str)

Extract data, mask, and feature dimensions from a directory of vectors.

Parameters:
  • dir_path – str Path to the directory containing data and mask files.

  • format – str Format of the input data. Supported values are ‘mtx’, ‘csv’, and ‘vec’.

Returns:

  • data – List[Dict[str, str]]

    List of dictionaries where keys are modalities and values are file paths.

  • mask – List[Dict[str, str]]

    List of dictionaries where keys are modalities and values are mask file paths.

  • dims_x – Dict[str, list]

    Dictionary containing feature dimensions for each modality.

Return type:

Tuple

Notes

The directory should be organized as:

dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv # the first sample
        vec/mod1/0001.csv # the second sample
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        vec/mod1/0000.csv
        vec/mod1/0001.csv
        ....
        vec/mod2/0000.csv
        vec/mod2/0001.csv
    ....

or like:

dataset/
    feat/
        # Dimensions of each modality: {mod1=[...], mod2=[...]}.
        # Split the data into chunks if the length of the list is greater than 1.
        # For instance, you can split the ATAC data by chromosomes.
        feat_dims.toml
    batch_0/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
        ....
    batch_1/
        mask/mod1.csv
        mask/mod2.csv
        mat/mod1.mtx (.csv)
        mat/mod2.mtx (.csv)
    ....
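
Before calling the loader, it can help to check that a directory follows the matrix layout shown above. The sketch below lists the relative paths that layout implies and reports which are missing. The helper names `expected_mat_layout` and `missing_files` are hypothetical, and the check looks only at file presence, not at file contents.

```python
import tempfile
from pathlib import Path
from typing import List


def expected_mat_layout(batches: List[List[str]], ext: str = "mtx") -> List[str]:
    """Relative paths implied by the 'mat' layout.

    batches[i] lists the modalities present in batch_i, so mosaic datasets
    with different modalities per batch can be described.
    """
    paths = ["feat/feat_dims.toml"]
    for b, mods in enumerate(batches):
        for mod in mods:
            paths.append(f"batch_{b}/mask/{mod}.csv")
            paths.append(f"batch_{b}/mat/{mod}.{ext}")
    return paths


def missing_files(root: Path, rel_paths: List[str]) -> List[str]:
    """Return the expected paths that do not exist under root."""
    return [p for p in rel_paths if not (root / p).is_file()]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "feat").mkdir()
    # Empty placeholder used only for the presence check, not a valid config.
    (root / "feat" / "feat_dims.toml").touch()
    layout = expected_mat_layout([["rna", "adt"]])
    missing = missing_files(root, layout)
```

Passing `ext="csv"` covers the CSV variant of the same layout.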

load_checkpoint(checkpoint_path: str, start_epoch: int = 0, **kwargs)

Load model and optimizer states from a checkpoint file.

Parameters:
  • checkpoint_path – str Path to the checkpoint file containing saved model and optimizer states.

  • start_epoch – int Number of epochs the model has already been trained, by default 0.

  • kwargs – Dict[str, Any] Additional configurations for torch.load().

Raises:

AssertionError – If the provided checkpoint path does not exist.

log_losses(recon_loss: Tensor, kld_loss, consistency_loss: Tensor, loss_net: Tensor, loss_dsc: Tensor, recon_dict: Dict[str, Tensor])

Log losses for monitoring and debugging during training.

Parameters:
  • recon_loss – torch.Tensor Reconstruction loss.

  • kld_loss – torch.Tensor KLD loss.

  • consistency_loss – torch.Tensor Consistency loss.

  • recon_dict – Dict[str, torch.Tensor] Per-modality reconstruction losses.

  • loss_net – torch.Tensor Total VAE loss.

  • loss_dsc – torch.Tensor Discriminator loss.

on_train_end()

Save the final model checkpoint at the end of training.

on_train_epoch_end()

Save a model checkpoint at the end of each training epoch with a meaningful filename.

pack_data(des_dir, subsample_num: list = None, subsample='random', format='mtx', use_transform: bool = False, return_idx=True)

Packs data from datalist into a specified directory structure, with an option for proportional subsampling.

predict(return_format='array', save_dir: str = None, save_format='h5ad', joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False) -> Dict[str, Dict[str, Tensor]] | None

Run model inference to generate latent embeddings, reconstructed data, or translated modalities.

The method iterates through the dataset, computes representations using the trained network, and optionally saves the results to disk or returns them in memory.

Parameters:
  • return_format (str, default='array') – The format of the returned data.

      - ‘array’: Returns a nested dictionary of NumPy arrays.
      - ‘anndata’: Returns a dictionary of AnnData objects (requires scanpy).

  • save_dir (str, optional, default=None) – Directory path to save the prediction results. If None, results are not saved to disk.

  • save_format (str, default='h5ad') – File format for saving results if save_dir is provided.

      - ‘h5ad’: Saves as H5AD files (one per batch).
      - ‘mtx’: Saves as Matrix Market files inside batch-specific folders.

  • joint_latent (bool, default=True) – If True, computes the joint latent representation (z) conditioned on all observed modalities.

  • mod_latent (bool, default=False) – If True, computes latent representations conditioned on each individual modality. (Automatically set to True if translate is True).

  • impute (bool, default=False) – If True, generates imputed data (x_impt) from the joint latent space.

  • batch_correct (bool, default=False) – If True, calculates batch centroids and performs batch-effect correction on the reconstructed data (x_bc). Note: This involves a second pass through the data.

  • translate (bool, default=False) – If True, performs cross-modality translation (e.g., generating Modality B from Modality A).

  • input (bool, default=False) – If True, includes the original input raw data and masks in the output.

Returns:

output – A dictionary where keys are batch names.

  1. If return_format=‘array’:

    {
        'batch_name': {
            'z_c': {'joint': np.ndarray, 'modality_name': …},  # Content latent
            'z_u': {'joint': np.ndarray, …},                   # Batch latent
            'x_impt': {'modality_name': np.ndarray},           # Imputed data
            'x_bc': {'modality_name': np.ndarray},             # Batch corrected data
            …
        }
    }

  2. If return_format=‘anndata’:

    {'batch_name': ann_data_object}

    Latent variables are stored in adata.obsm, masks in adata.uns.

Return type:

Dict
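
To make the nested ‘array’ structure concrete, the sketch below walks a mock result and concatenates the joint biological embeddings (‘z_c’) across batches while tracking which batch each row came from. The helper name `collect_joint_latents` is hypothetical, and nested lists stand in for the NumPy arrays that predict() is documented to return, so no model or data is needed to run it.

```python
from typing import Any, Dict, List, Tuple


def collect_joint_latents(
    output: Dict[str, Dict[str, Any]], key: str = "z_c"
) -> Tuple[List[List[float]], List[str]]:
    """Stack output[batch][key]['joint'] over batches and keep batch labels."""
    rows: List[List[float]] = []
    labels: List[str] = []
    for batch_name, result in output.items():
        for row in result[key]["joint"]:
            rows.append(list(row))
            labels.append(batch_name)
    return rows, labels


# Mock data shaped like the documented 'array' return value.
mock_output = {
    "batch_0": {"z_c": {"joint": [[0.1, 0.2], [0.3, 0.4]]}},
    "batch_1": {"z_c": {"joint": [[0.5, 0.6]]}},
}
rows, labels = collect_joint_latents(mock_output)
```

The resulting rows and labels are in a convenient form for downstream plotting or clustering across batches.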

static print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str])

Print summary of mask density and dataset information.

Parameters:
  • mask – List[Dict[str, str]] List of mask dictionaries.

  • datalist – List[Dataset] List of datasets.

  • batch_names – List[str] List of batch names.

classmethod reset()

Reset class-level attributes to their initial states.

save_checkpoint(checkpoint_path: str)

Save the current model and optimizer states to a checkpoint file.

Parameters:

checkpoint_path – str Path to save the checkpoint file.

Raises:

ValueError – If checkpoint_path is an invalid or empty string.

train_dataloader() -> DataLoader

Create a DataLoader for training, using the appropriate sampler.

Returns:

Configured DataLoader instance for training.

Return type:

DataLoader

train_discriminator(c_all: Dict[str, Tensor], targets: Dict[str, Tensor], scale: float = 1.0) -> None

Train the discriminator with modality-specific latent representations.

Parameters:
  • c_all – Dict[str, torch.Tensor] Dictionary of latent representations for each modality.

  • targets – Dict[str, torch.Tensor] Ground truth batch labels for each modality.

training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) -> Tensor

Executes a single training step for MIDAS.

Parameters:
  • batch – Dict[str, Dict[str, torch.Tensor]] Input batch containing modality data, batch indices, and masks.

  • batch_idx – int Index of the current training batch.

Returns:

Total VAE loss for the current batch.

Return type:

torch.Tensor

update(model2: Module)

Align and update weights from model2 to this model (model1) by left-aligning the weights.

Parameters:

model2 (nn.Module) – Source model from which weights are taken.

static update_model(loss: Tensor, model: Module, optimizer: Optimizer, grad_clip: int = -1)

Update model parameters using backpropagation.

Parameters:
  • loss – torch.Tensor Computed loss for backpropagation.

  • model – torch.nn.Module Model to update.

  • optimizer – torch.optim.Optimizer Optimizer for parameter updates.

  • grad_clip – int True to allow clipping gradient.

class scmiracle.model.S_Decoder(n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float)

Bases: Module

Decoder for reconstructing batch ID.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_dec_s – List[int] List of dimensions for hidden layers in the decoder.

  • dim_u – int Latent dimension size for the input (e.g., 2).

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) -> Tensor

Forward pass for S_Decoder.

Parameters:

data – torch.Tensor Latent input tensor of shape (batch_size, dim_u).

Returns:

Reconstructed tensor of shape (batch_size, n_batches).

Return type:

torch.Tensor

class scmiracle.model.S_Encoder(n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float)

Bases: Module

Encoder for batch ID latent variables.

Parameters:
  • n_batches – int Number of distinct batches.

  • dims_enc_s – List[int] List of dimensions for hidden layers in the encoder.

  • dim_z – int Latent dimension size.

  • norm – str Normalization type (e.g., ‘ln’ for LayerNorm).

  • drop – float Dropout rate.

forward(data: Tensor) -> Tensor

Forward pass for S_Encoder.

Parameters:

data – torch.Tensor Input tensor of shape (batch_size, 1), containing batch indices.

Returns:

Encoded tensor of shape (batch_size, dim_z * 2).

Return type:

torch.Tensor

class scmiracle.model.VAE(dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs)

Bases: Module

Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling from distributions.

Parameters:
  • dims_x – Dict[str, list] Input dimensions for each modality, e.g., {‘rna’: [1000], ‘adt’: [100], ‘atac’: [10, 10, 10]}.

  • dims_s – Dict[str, int] Dimensions of the classes for each modality.

  • kwargs – Dict[str, Any] Additional configurations for encoders, decoders, and other modules.

encode_batch(s: Tensor) -> Tuple[list, list] | None

Encode batch IDs into latent variables.

Parameters:

s – torch.Tensor Batch IDs.

Returns:

  • z_s_mu – List[torch.Tensor]

    Mean of the batch-ID latent variables.

  • z_s_logvar – List[torch.Tensor]

    Log-variance of the batch-ID latent variables.

Return type:

Optional[Tuple[list, list]]

forward(data: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Tensor | None, Tensor, Tensor, Tensor, Tensor, Tensor, Dict[str, Tensor], Dict[str, Tensor]]

Forward pass for the VAE.

Parameters:

data – Dict[str, torch.Tensor] Input data dictionary containing:

  • ‘x’: Dict[str, torch.Tensor], modality-specific input data.

  • ‘e’: Dict[str, torch.Tensor], modality-specific masks.

  • ‘s’ (optional): torch.Tensor, dimensions of the output classes for each modality.

Returns:

  • x_r_pre – Dict[str, torch.Tensor]

    Reconstructed modality-specific data.

  • s_r_pre – Optional[torch.Tensor]

    Reconstructed batch indices if ‘s’ is provided, otherwise None.

  • z_mu – torch.Tensor

    Mean of the combined latent variables.

  • z_logvar – torch.Tensor

    Log-variance of the combined latent variables.

  • z – torch.Tensor

    Sampled latent variables.

  • c – torch.Tensor

    Biological information variables.

  • u – torch.Tensor

    Technical noise variables.

  • z_uni – Dict[str, torch.Tensor]

    Unified latent variables for each modality.

  • c_all – Dict[str, torch.Tensor]

    Modality-specific biological information variables.

Return type:

Tuple

gen_real_data(x_r_pre: Dict[str, Tensor], sampling: bool = True) -> Dict[str, Tensor]

Generate real data from reconstructed data.

Parameters:
  • x_r_pre – Dict[str, torch.Tensor] Dictionary of reconstructed data tensors for each modality.

  • sampling – bool, optional Whether to sample the output (default: True).

Returns:

Generated real data for each modality.

Return type:

Dict[str, torch.Tensor]

generate_unified_latent(z_x_mu: Dict[str, Tensor], z_x_logvar: Dict[str, Tensor], z_s_mu: List[Tensor], z_s_logvar: List[Tensor], c: Tensor) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]

Generate unified latent variables and modality-specific representations.

Parameters:
  • z_x_mu – Dict[str, torch.Tensor] Means of modality-specific latent variables.

  • z_x_logvar – Dict[str, torch.Tensor] Log-variances of modality-specific latent variables.

  • z_s_mu – List[torch.Tensor] Mean of the batch-ID latent variables.

  • z_s_logvar – List[torch.Tensor] Log-variance of the batch-ID latent variables.

  • c – torch.Tensor Biological information.

Returns:

  • z_uni – Dict[str, torch.Tensor]

    Collection of latent variables for the unimodal inputs.

  • c_all – Dict[str, torch.Tensor]

    Collection of biological information for the unimodal and joint inputs.

Return type:

Tuple

get_dim_h() -> Dict[str, List[int]]

Compute hidden dimensions for each modality.

Returns:

A dictionary containing the hidden dimensions for each modality.

Return type:

Dict[str, List[int]]

static poe(mus: List[Tensor], logvars: List[Tensor]) -> Tuple[Tensor, Tensor]

Product of Experts (PoE) for combining Gaussian distributions.

Parameters:
  • mus – list of torch.Tensor List of mean tensors for each Gaussian.

  • logvars – list of torch.Tensor List of log-variance tensors for each Gaussian.

Returns:

  • combined_mean – torch.Tensor

    Mean of the combined Gaussian distribution.

  • combined_logvar – torch.Tensor

    Log-variance of the combined Gaussian distribution.

Return type:

Tuple
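
The standard product-of-experts rule for diagonal Gaussians weights each expert's mean by its precision (inverse variance); the combined precision is the sum of the experts' precisions. The sketch below implements that rule on plain lists. Whether the library also multiplies in a standard normal prior expert is not stated above, so the `include_prior` flag is an assumption, exposed here only to show both variants.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]  # shape: batch_size x latent_dim


def poe(mus: List[Matrix], logvars: List[Matrix],
        include_prior: bool = True) -> Tuple[Matrix, Matrix]:
    """Combine Gaussian experts by precision weighting."""
    n_rows, n_cols = len(mus[0]), len(mus[0][0])
    if include_prior:
        # Hypothetical option: add an N(0, 1) expert (mean 0, log-variance 0).
        zeros = [[0.0] * n_cols for _ in range(n_rows)]
        mus = [zeros] + list(mus)
        logvars = [zeros] + list(logvars)
    out_mu = [[0.0] * n_cols for _ in range(n_rows)]
    out_logvar = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for j in range(n_cols):
            precisions = [math.exp(-lv[i][j]) for lv in logvars]
            t_sum = sum(precisions)
            out_mu[i][j] = sum(m[i][j] * t for m, t in zip(mus, precisions)) / t_sum
            out_logvar[i][j] = -math.log(t_sum)  # log(1 / total precision)
    return out_mu, out_logvar


mu, logvar = poe([[[1.0]], [[3.0]]], [[[0.0]], [[0.0]]], include_prior=False)
```

Because precisions add, the combined distribution is never wider than its most confident expert.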

static sample(name: str, data: Tensor, sampling: bool) -> Tensor

Map a sampling function based on the distribution name.

Parameters:
  • name – str Name of the distribution.

  • data – torch.Tensor Input data tensor.

  • sampling – bool Whether to apply sampling.

Returns:

torch.Tensor

Sampled or original data tensor.

static sample_gaussian(mu: Tensor, logvar: Tensor) -> Tensor

Sample from a Gaussian distribution using the reparameterization trick.

Parameters:
  • mu – torch.Tensor Mean of the Gaussian distribution.

  • logvar – torch.Tensor Log-variance of the Gaussian distribution.

Returns:

torch.Tensor

Sampled tensor.
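
The reparameterization trick draws standard normal noise and then shifts and scales it: z = mu + exp(0.5 * logvar) * eps. In a tensor framework, this keeps the sample differentiable with respect to mu and logvar. The sketch below shows only the arithmetic on plain lists; the explicit `rng` argument is an addition for reproducibility and is not part of the documented signature.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def sample_gaussian(mu: Matrix, logvar: Matrix, rng: random.Random) -> Matrix:
    """Return mu + std * eps with eps drawn from N(0, 1) per entry."""
    return [
        [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu_row, lv_row)]
        for mu_row, lv_row in zip(mu, logvar)
    ]


rng = random.Random(0)
# Variance 0.25 corresponds to a standard deviation of 0.5.
z = sample_gaussian([[2.0, -1.0]], [[math.log(0.25), math.log(0.25)]], rng)
```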

sample_latent(z_mu: Tensor, z_logvar: Tensor) -> Tensor

Sample latent variables from a Gaussian distribution.

Parameters:
  • z_mu – torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim).

  • z_logvar – torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim).

Returns:

Sampled latent variables of shape (batch_size, latent_dim).

Return type:

torch.Tensor