Source code for scmiracle.model

import os
import datetime

from typing import Dict, List, Optional, Tuple
import natsort

import toml
import json
import pandas as pd
import scanpy as sc
from tqdm import tqdm
import random
from matplotlib import pyplot as plt
from scipy.sparse import csr_matrix

import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset, Dataset
import lightning as L
from pytorch_lightning.utilities import rank_zero_only

import logging
logging.basicConfig(level=logging.INFO)

# Project-Specific Imports
from .data import MyDistributedSampler, MultiBatchSampler, MultiModalDataset, MultiBatchContinualLearningSampler
from .utils import *
from .nn import MLP, Layer1D, distribution_registry, transform_registry
from copy import deepcopy


[docs] class Encoder(nn.Module): """ Encoder class for multi-modal data with modality-specific pre-processing, encoding, and shared encoding layers. Parameters: dims_x : Dict[str, list] Input dimensions for each modality (e.g, {'rna':[1000], 'adt':[100]}). dims_h : Dict[str, list] Hidden dimensions for each modality after pre-encoding (e.g, {'rna':256, 'adt':256}). dim_z : int Latent dimension size (e.g, 32). norm : str Normalization type (e.g., 'ln' for LayerNorm). out_trans : str Output activation function (e.g., 'mish'). drop : float Dropout rate. kwargs : Dict[str, Any] Additional modality-specific configurations. Notes: By default, RNA and ADT data are log1p-transformed in the encoder and will be exponentiated after decoding. To skip this step, modify the configuration file. See parameter 'trsf_before_enc_'. """ def __init__( self, dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs, ): super(Encoder, self).__init__() self.dims_x = dims_x self.dims_h = dims_h self.dim_z = dim_z self.norm = norm self.out_trans = out_trans self.drop = drop # Dynamically set additional arguments as attributes for key, value in kwargs.items(): setattr(self, key, value) # Extract transformations to apply before encoding self.trsf_before_enc = filter_keys(kwargs, 'trsf_before_enc') # Shared encoder across all modalities shared_encoder = MLP( self.dims_shared_enc + [self.dim_z * 2], hid_norm=self.norm, hid_drop=self.drop, ) # Initialize modality-specific encoders # mod1 -> (opt) transform[mod1] -> (opt) pre_encoder[mod1] -> # (opt) transform_concat[mod1] -> indiv_enc[mod1] -> share_encoder -> z_mod1 self.pre_encoders = nn.ModuleDict() # Modality-specific pre-encoding layers self.transform_concat = nn.ModuleDict() # Post-concatenation layers encoders = {} # Final encoders for each modality for modality, input_dims in dims_x.items(): # For truncated input, such as ATAC if len(input_dims) > 1: self.pre_encoders[modality] = nn.ModuleList([ MLP([dim] + kwargs[f'dims_before_enc_{modality}'], hid_norm=self.norm, hid_drop=self.drop) for dim in input_dims ]) self.transform_concat[modality] = Layer1D(self.dims_h[modality], self.norm, self.out_trans, self.drop) # Create individual encoder for the modality indiv_enc = MLP( [self.dims_h[modality][0], self.dims_shared_enc[0]], out_trans=self.out_trans, norm=self.norm, drop=self.drop, ) encoders[modality] = nn.Sequential(indiv_enc, shared_encoder) self.encoders = nn.ModuleDict(encoders)
[docs] def forward( self, data: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Forward pass for the encoder. Parameters: data : Dict[str, torch.Tensor] Input data for each modality. mask : Dict[str, torch.Tensor] Masks for each modality. Returns: Tuple: - z_x_mu : Dict[str, torch.Tensor] Mean values for latent space for each modality. - z_x_logvar : Dict[str, torch.Tensor] Log-variance values for latent space for each modality. """ data = data.copy() mask = mask.copy() # Apply transformations before encoding for modality in data.keys(): if f'trsf_before_enc_{modality}' in self.trsf_before_enc: transformation = self.trsf_before_enc[f'trsf_before_enc_{modality}'] data[modality] = transform_registry.get(transformation)(data[modality]) # Apply masks to data for modality, mask_value in mask.items(): data[modality] *= mask_value # Pre-encode and concatenate if necessary, for truncated inputs for modality in data.keys(): if modality in self.pre_encoders: # Split and process individual dimensions batches = data[modality].split(self.dims_x[modality], dim=1) processed_batches = [ self.pre_encoders[modality][i](batch) for i, batch in enumerate(batches) ] # Concatenate processed batches and transform data[modality] = self.transform_concat[modality](torch.cat(processed_batches, dim=1)) # Encode data and split into mean and log-variance z_x_mu, z_x_logvar = {}, {} for modality, modality_data in data.items(): encoded = self.encoders[modality](modality_data) z_x_mu[modality], z_x_logvar[modality] = encoded.split(self.dim_z, dim=1) return z_x_mu, z_x_logvar
[docs] class Decoder(nn.Module): """ Decoder class for multi-modal data with shared and modality-specific decoding layers. Parameters: dims_x : Dict[str, list] Output dimensions for each modality. dims_h : Dict[str, list] Hidden dimensions for each modality. dim_z : int Latent dimension size. norm : str Normalization type (e.g., 'ln' for LayerNorm). out_trans : str Output activation function (e.g., 'relu'). drop : float Dropout rate. kwargs : Dict[str, Any] Additional modality-specific configurations. """ def __init__( self, dims_x: Dict[str, list], dims_h: Dict[str, list], dim_z: int, norm: str, out_trans: str, drop: float, **kwargs, ): super(Decoder, self).__init__() self.dims_x = dims_x self.dims_h = dims_h self.dim_z = dim_z self.norm = norm self.out_trans = out_trans self.drop = drop # Dynamically set additional arguments as attributes for key, value in kwargs.items(): setattr(self, key, value) # z -> shared_decoder -> (opt) post_decoders[mod1] -> (opt) transform_concat[mod1] -> mod1 # Shared decoder layer total_hidden_dims = sum(dim[0] for dim in dims_h.values()) self.shared_decoder = MLP( [self.dim_z] + self.dims_shared_dec + [total_hidden_dims], hid_norm=self.norm, hid_drop=self.drop, ) # Modality-specific decoders self.post_decoders = nn.ModuleDict() self.transform_concat = nn.ModuleDict() for modality, output_dims in dims_x.items(): # Modality-specific post-decoding layers if len(output_dims) > 1: self.post_decoders[modality] = nn.ModuleList([ MLP(kwargs[f'dims_after_dec_{modality}'] + [dim], hid_norm=self.norm, hid_drop=self.drop) for dim in output_dims ]) # Layer to process concatenated outputs self.transform_concat[modality] = Layer1D(self.dims_h[modality], self.norm, self.out_trans, self.drop)
[docs] def forward(self, latent_data: torch.Tensor) -> Dict[str, torch.Tensor]: """ Forward pass for the decoder. Parameters: latent_data : torch.Tensor Latent variable input tensor of shape (batch_size, dim_z). Returns: Dict[str, torch.Tensor] : Decoded outputs for each modality. """ # Pass through the shared decoder shared_output = self.shared_decoder(latent_data) # Split shared decoder output into modality-specific chunks modality_outputs = shared_output.split( [dim[0] for dim in self.dims_h.values()], dim=1, ) # Create a dictionary to hold the modality-specific outputs data_dict = {modality: output for modality, output in zip(self.dims_x.keys(), modality_outputs)} # Process each modality-specific output for modality, post_decoders in self.post_decoders.items(): # Apply transformation layer processed_output = self.transform_concat[modality](data_dict[modality]) batches = processed_output.split(self.__dict__[f'dims_after_dec_{modality}'][0], dim=1) # Apply modality-specific post-decoders data_dict[modality] = torch.cat( [post_decoders[i](batch) for i, batch in enumerate(batches)], dim=1, ) # Apply activation functions based on distribution for modality, output in data_dict.items(): distribution = self.__dict__[f'distribution_dec_{modality}'] activation_fn = distribution_registry.get_activate(distribution) data_dict[modality] = activation_fn(output) return data_dict
[docs] class S_Encoder(nn.Module): """ Encoder for batch ID latent variables. Parameters: n_batches : int Number of distinct batches. dims_enc_s : List[int] List of dimensions for hidden layers in the encoder. dim_z : int Latent dimension size for the latent. norm : str Normalization type (e.g., 'ln' for LayerNorm). drop : float Dropout rate. """ def __init__( self, n_batches: int, dims_enc_s: List[int], dim_z: int, norm: str, drop: float ): super(S_Encoder, self).__init__() self.n_batches = n_batches self.dims_enc_s = dims_enc_s self.dim_z = dim_z self.norm = norm self.drop = drop # Define the encoder MLP self.s_encoder = MLP( [self.n_batches] + self.dims_enc_s + [self.dim_z * 2], hid_norm=self.norm, hid_drop=self.drop, )
[docs] def forward(self, data: torch.Tensor) -> torch.Tensor: """ Forward pass for S_Encoder. Parameters: data : torch.Tensor Input tensor of shape (batch_size, 1), containing batch indices. Returns: torch.Tensor : Encoded tensor of shape (batch_size, dim_z * 2). """ # One-hot encode the batch indices one_hot_data = nn.functional.one_hot(data.squeeze(1), num_classes=self.n_batches).float() # Pass through the encoder return self.s_encoder(one_hot_data)
[docs] class S_Decoder(nn.Module): """ Decoder for reconstructing batch ID. Parameters: n_batches : int Number of distinct batches. dims_dec_s : List[int] List of dimensions for hidden layers in the decoder. dim_u : int Latent dimension size for the input (e.g, 2). norm : str Normalization type (e.g., 'ln' for LayerNorm). drop : float Dropout rate. """ def __init__( self, n_batches: int, dims_dec_s: List[int], dim_u: int, norm: str, drop: float): super(S_Decoder, self).__init__() self.n_batches = n_batches self.dims_dec_s = dims_dec_s self.dim_u = dim_u self.norm = norm self.drop = drop # Define the decoder MLP self.s_decoder = MLP( [self.dim_u] + self.dims_dec_s + [self.n_batches], hid_norm=self.norm, hid_drop=self.drop, )
[docs] def forward(self, data: torch.Tensor) -> torch.Tensor: """ Forward pass for S_Decoder. Parameters: data : torch.Tensor Latent input tensor of shape (batch_size, dim_u). Returns: torch.Tensor : Reconstructed tensor of shape (batch_size, n_batches). """ return self.s_decoder(data)
[docs] class VAE(nn.Module): """ Variational Autoencoder (VAE) for multi-modal data, supporting batch correction and sampling from distributions. Parameters: dims_x : Dict[str, list] Input dimensions for each modality, e.g {'rna'=[1000], 'adt'=[100], 'atac'=[10,10,10]}. dims_s : Dict[str, int] Dimensions of the classes for each modality. kwargs : Dict[str, Any] Additional configurations for encoders, decoders, and other modules. """ def __init__(self, dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs): super(VAE, self).__init__() self.dims_x = dims_x self.dims_s = dims_s self.mods = set(dims_x.keys()) logging.debug(f'Initializing VAE with modalities: {self.mods}') logging.debug(f'Initializing VAE with dims_s: {self.dims_s}') logging.debug(f'Initializing VAE with dims_x: {self.dims_x}') # Dynamically set additional arguments for key, value in kwargs.items(): setattr(self, key, value) self.available_mods = set(self.dims_x.keys()) self.dim_z = self.dim_c + self.dim_u self.dims_h = self.get_dim_h() self.n_batches = dims_s['joint'] # Initialize modules self.encoder = Encoder(self.dims_x, self.dims_h, self.dim_z, self.norm, self.out_trans, self.drop, **filter_keys(self.__dict__, '_enc')) self.decoder = Decoder(self.dims_x, self.dims_h, self.dim_z, self.norm, self.out_trans, self.drop, **filter_keys(self.__dict__, '_dec')) self.s_encoder = S_Encoder(self.n_batches, self.dims_enc_s, self.dim_z, self.norm, self.drop) self.s_decoder = S_Decoder(self.n_batches, self.dims_dec_s, self.dim_u, self.norm, self.drop) # Batch correction and sampling configurations self.batch_correction = False self.u_centroid = None self.drop_s = False self.sampling = False self.sample_num = 0
[docs] def forward(self, data: Dict[str, torch.Tensor] ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Forward pass for the VAE. Parameters: data : Dict[str, torch.Tensor] Input data dictionary containing: - 'x': Dict[str, torch.Tensor], modality-specific input data. - 'e': Dict[str, torch.Tensor], modality-specific masks. - 's' (optional): torch.Tensor, dimensions of the output classes for each modality. Returns: Tuple: - x_r_pre : Dict[str, torch.Tensor] Reconstructed modality-specific data. - s_r_pre : Optional[torch.Tensor] If 's' is provided, return reconstructed batch indices. If 's' is not provided, return None. - z_mu : torch.Tensor Mean of the combined latent variables. - z_logvar : torch.Tensor Log-variance of the combined latent variables. - z : torch.Tensor Sampled latent variables. - c : torch.Tensor Biological information variables. - u : torch.Tensor Technical noise variables. - z_uni : Dict[str, torch.Tensor] Unified latent variables for each modality. - c_all : Dict[str, torch.Tensor] Modality-specific Biological information variables. """ x = data['x'] e = data['e'] s = None # Handle batch-specific information. See https://github.com/labomics/midas/issues/12. if not self.drop_s and 's' in data: s_drop_rate = self.s_drop_rate if self.training else 0 if torch.rand([]).item() < 1 - s_drop_rate: s = data['s'] # Encode data # check device: logging.debug(f"x device: {next(iter(x.values())).device}") logging.debug(f"model device: {next(self.parameters()).device}") z_x_mu, z_x_logvar = self.encoder(x, e) z_s_mu, z_s_logvar = self.encode_batch(s) # Combine latent variables using Product of Experts try: z_mu, z_logvar = self.poe(list(z_x_mu.values()) + z_s_mu, list(z_x_logvar.values()) + z_s_logvar) except: logging.debug(z_x_mu, z_s_mu, x, e, s) # Sample latent variables z = self.sample_latent(z_mu, z_logvar) # Split latent variables into c and u c, u = z.split([self.dim_c, self.dim_u], dim=1) # Perform batch correction if enabled if self.batch_correction: z[:, self.dim_c:] = self.u_centroid.type_as(z).unsqueeze(0) # Decode data x_r_pre = self.decoder(z) # Decode batch-specific information s_r_pre = self.s_decoder(u) if s is not None else None # Generate unified latent variables and modality-specific c z_uni, c_all = self.generate_unified_latent(z_x_mu, z_x_logvar, z_s_mu, z_s_logvar, c) return x_r_pre, s_r_pre, z_mu, z_logvar, z, c, u, z_uni, c_all
[docs] def encode_batch(self, s: torch.Tensor) -> Optional[Tuple[list, list]]: """ Encode batch IDs latent variables. Parameters: s : torch.Tensor Batch IDs. Returns: Optional[Tuple[list, list]]: - z_s_mu : List[torch.Tensor] Mean of batch IDs latent variables. - z_s_logvar : List[torch.Tensor] Log-variance of batch IDs latent variables. """ if s is not None: z_s_mu, z_s_logvar = self.s_encoder(s['joint']).split(self.dim_z, dim=1) return [z_s_mu], [z_s_logvar] return [], []
[docs] def sample_latent(self, z_mu: torch.Tensor, z_logvar: torch.Tensor) -> torch.Tensor: """ Sample latent variables from a Gaussian distribution. Parameters: z_mu : torch.Tensor Mean of the latent variables of shape (batch_size, latent_dim). z_logvar : torch.Tensor Log-variance of the latent variables of shape (batch_size, latent_dim). Returns: torch.Tensor: Sampled latent variables of shape (batch_size, latent_dim). """ if self.training: return self.sample_gaussian(z_mu, z_logvar) elif self.sampling and self.sample_num > 0: z_mu_expand = z_mu.unsqueeze(1) z_logvar_expand = z_logvar.unsqueeze(1).expand(-1, self.sample_num, self.dim_z) return self.sample_gaussian(z_mu_expand, z_logvar_expand).reshape(-1, self.dim_z) return z_mu
[docs] def generate_unified_latent( self, z_x_mu: Dict[str, torch.Tensor], z_x_logvar: Dict[str, torch.Tensor], z_s_mu: List[torch.Tensor], z_s_logvar: List[torch.Tensor], c: torch.Tensor, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Generate unified latent variables and modality-specific representations. Parameters: z_x_mu : Dict[str, torch.Tensor] Means of modality-specific latent variables. z_x_logvar : Dict[str, torch.Tensor] Log-variances of modality-specific latent variables. z_s_mu : List[torch.Tensor] Mean of the batch-ID latent variables. z_s_logvar : List[torch.Tensor] Log-variance of the batch-ID latent variables. c : torch.Tensor Biological information. Returns: Tuple: - z_uni : Dict[str, torch.Tensor]: Collection of latent variables for the unimodal inputs. - c_all : Dict[str, torch.Tensor]: Collection of biological information for the unimodal and joint inputs. """ z_uni = {} c_all = {} for modality, z_x_mu_mod in z_x_mu.items(): # Combine modality-specific and batch-specific latent variables z_uni_mu, z_uni_logvar = self.poe([z_x_mu_mod] + z_s_mu, [z_x_logvar[modality]] + z_s_logvar) # fix here z_uni[modality] = self.sample_latent(z_uni_mu, z_uni_logvar) # Extract shared latent representation (biological information) c_all[modality] = z_uni[modality][:, :self.dim_c] # Add joint representation c_all['joint'] = c return z_uni, c_all
[docs] def get_dim_h(self) -> Dict[str, List[int]]: """ Compute hidden dimensions for each modality. Returns: Dict[str, List[int]]: A dictionary containing the hidden dimensions for each modality. """ dims_h = self.dims_x.copy() # Adjust dimensions based on pre-encoding layers for key in filter_keys(self.__dict__, 'dims_before_enc_'): modality = key.split('_')[-1] if (modality in self.dims_x) and (len(self.dims_x[modality]) > 1): dims_h[modality] = [sum([self.__dict__[key][-1]] * len(self.dims_x[modality]))] return dims_h
[docs] def gen_real_data(self, x_r_pre: Dict[str, torch.Tensor], sampling: bool = True) -> Dict[str, torch.Tensor]: """ Generate real data from reconstructed data. Parameters: x_r_pre : Dict[str, torch.Tensor] Dictionary of reconstructed data tensors for each modality. sampling : bool, optional Whether to sample the output (default: True). Returns: Dict[str, torch.Tensor]: Generated real data for each modality. """ x_r = {} for modality, tensor in x_r_pre.items(): # Apply inverse transformations if needed if f'trsf_before_enc_{modality}' in self.__dict__: tensor = reverse_trsf(self.__dict__[f'trsf_before_enc_{modality}'].split('_')[-1], tensor) # Apply sampling or directly return the data x_r[modality] = self.sample( self.__dict__[f'distribution_dec_{modality}'].split('_')[-1], tensor, sampling) return x_r
[docs] @staticmethod def sample(name: str, data: torch.Tensor, sampling: bool) -> torch.Tensor: """ Map a sampling function based on the distribution name. Parameters: name : str Name of the distribution. data : torch.Tensor Input data tensor. sampling : bool Whether to apply sampling. Returns: torch.Tensor Sampled or original data tensor. """ if sampling: return distribution_registry.get_sampling(name)(data) return data
[docs] @staticmethod def sample_gaussian(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: """ Sample from a Gaussian distribution using the reparameterization trick. Parameters: mu : torch.Tensor Mean of the Gaussian distribution. logvar : torch.Tensor Log-variance of the Gaussian distribution. Returns: torch.Tensor Sampled tensor. """ std = (0.5 * logvar).exp() eps = torch.randn_like(std) return mu + std * eps
[docs] @staticmethod def poe(mus: List[torch.Tensor], logvars: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Product of Experts (PoE) for combining Gaussian distributions. Parameters: mus : list of torch.Tensor List of mean tensors for each Gaussian. logvars : list of torch.Tensor List of log-variance tensors for each Gaussian. Returns: Tuple : - combined_mean: torch.Tensor Mean of the combined Gaussian distribution. - combined_logvar: torch.Tensor Log-variance of the combined Gaussian distribution. """ # Add prior distributions with zero mean and unit variance try: mus = [torch.zeros_like(mus[0])] + mus except: logging.debug(mus) logvars = [torch.zeros_like(logvars[0])] + logvars # Calculate precision and combined precision precisions = torch.exp(-torch.stack(logvars, dim=1)) # Shape: (batch_size, num_experts, latent_dim) precision_sum = precisions.sum(dim=1) # Calculate combined mean and variance weighted_means = (torch.stack(mus, dim=1) * precisions).sum(dim=1) combined_mean = weighted_means / precision_sum combined_logvar = torch.log(1 / precision_sum) return combined_mean, combined_logvar
[docs] class Discriminator(nn.Module): """ Discriminator class for multi-modal latent variables. Parameters: dims_x : Dict[str, list] Input dimensions for each modality. dims_s : Dict[str, int] Dimensions of the classes for each modality. kwargs : Dict[str, Any] Additional configurations, such as hidden layer sizes, dropout rate, and normalization type. """ def __init__(self, dims_x: Dict[str, list], dims_s: Dict[str, int], **kwargs): super(Discriminator, self).__init__() self.dims_x = dims_x self.dims_s = dims_s # Dynamically set additional arguments as attributes for key, value in kwargs.items(): setattr(self, key, value) # Combine modality keys with 'joint' modality self.modalities = list(self.dims_x.keys()) + ['joint'] # Create predictors for each modality self.predictors = nn.ModuleDict({ modality: MLP( [self.dim_c] + self.dims_dsc + [self.dims_s[modality]], hid_norm=self.norm, hid_drop=self.drop ) for modality in self.modalities }) # Cross-entropy loss function self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum') # log_softmax + nll
[docs] def forward(self, latent_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Forward pass for the discriminator. Parameters: latent_inputs : Dict[str, torch.Tensor] Dictionary of latent inputs for each modality, where keys are modality names and values are tensors of shape (batch_size, dim_c). Returns: Dict[str, torch.Tensor] : Dictionary of logits for each modality, where keys are modality names and values are tensors of shape (batch_size, dims_s[modality]). """ return {modality: self.predictors[modality](latent_input) for modality, latent_input in latent_inputs.items()}
[docs] def calculate_loss(self, predictions: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculate cross-entropy loss for all modalities. Parameters: predictions : Dict[str, torch.Tensor] Dictionary of predicted logits for each modality. targets : Dict[str, torch.Tensor] Dictionary of ground truth labels for each modality. Returns: torch.Tensor : Total normalized loss. """ total_loss = sum( self.cross_entropy_loss(pred, targets[modality].squeeze(1)) for modality, pred in predictions.items() ) # Normalize the total loss by the batch size of the joint modality batch_size = predictions['joint'].size(0) return total_loss / batch_size
[docs] class MIDAS(L.LightningModule): """ MIDAS processes mosaic single-cell data into imputed and batch-corrected data for multimodal analysis. Attributes: net : VAE Variational Autoencoder for multi-modal data encoding and decoding. dsc : Discriminator Discriminator for distinguishing latent variables across batches. configs : Dict[str, Any] Model and training configurations dynamically set as attributes. automatic_optimization : bool Controls whether optimization is automatic or manually defined. Always True. """ def __init__(self): super(MIDAS, self).__init__() self.net = MIDAS.net self.dsc = MIDAS.dsc for key, value in self.configs.items(): setattr(self, key, value) # Disable automatic optimization to manually control training steps. Always True. self.automatic_optimization = False
[docs] def update(model1: nn.Module, model2: nn.Module): """ Align and update weights from model2 to model1 by left-aligning the weights. Args: model1 (nn.Module): target model to be updated. model2 (nn.Module): source model from which weights are taken. """ model1_state_dict = model1.state_dict() model2_state_dict = model2.state_dict() for name, param2 in model2_state_dict.items(): if name in model1_state_dict: param1 = model1_state_dict[name] if isinstance(param1, torch.Tensor) and isinstance(param2, torch.Tensor): min_dims = [min(d1, d2) for d1, d2 in zip(param1.shape, param2.shape)] slice_tuple = tuple(slice(0, dim) for dim in min_dims) param1[slice_tuple].copy_(param2[slice_tuple]) logging.debug(f"Updated parameter '{name}' with left-aligned weights.") else: if param1.shape == param2.shape: param1.copy_(param2) else: logging.debug(f"Warning: Buffer '{name}' has different shapes ({param1.shape} vs {param2.shape}), skipping direct copy.") else: logging.debug(f"Warning: Parameter '{name}' from model2 not found in model1.") model1.load_state_dict(model1_state_dict) logging.info("Model1 updated successfully with weights from model2 (left-aligned strategy).") return model1
[docs] @classmethod def configure_data( cls, configs: dict, datalist: List[Dataset], dims_x: Dict[str, list], dims_s: Dict[str, int] , s_joint: List[Dict[str, int]], combs: List[List[str]], batch_size: int = 256, n_save:int = 500, save_model_path: str = './saved_models/', sampler_type:str = 'auto', viz_umap_tb=False, batch_names=None, ) -> 'MIDAS': """ Configure the data and model parameters for training. Parameters: configs : dict, Configurations of the model. datalist : List[Dataset] List of datasets to be used for training. dims_x : Dict[str, list] Dictionary specifying the dimensions of input features for each modality. dims_s : Dict[str, int] Dimensions of the classes for each modality. s_joint : List[Dict[str, int]] Modality ID for each batch. combs : List[List[str]] Combinations of modalities. batch_size : int, optional Size of each training batch, by default 256. n_save : int, optional Interval (in epochs) for saving model checkpoints, by default 500. save_model_path : str, optional Directory path for saving model checkpoints, by default './saved_models/'. sampler_type : str, optional Type of sampler to use, by default 'auto'. For 'ddp', use distributed sampler. viz_umap_tb: bool, optional Whether to visualize UMAP embeddings in TensorBoard, by default False. batch_names: list, optional List of batch names, by default None. Returns: class 'MIDAS': Returns MIDAS instance. """ # Set class-level attributes cls.configs = configs # check config atac_dims = dims_x.get('atac', None) if atac_dims is not None and len(atac_dims) == 1: logging.warning( f"Detected ATAC with only one dimension [{atac_dims[0]}]. " "This will cause the data to be encoded directly instead of by chromosome, as described in our paper. " "We recommend splitting the ATAC data by chromosome." ) if 'dims_before_enc_atac' in configs and 'dims_after_dec_atac' in configs: logging.error( 'Invalid ATAC configuration: both "dims_before_enc_atac" and "dims_after_dec_atac" exist in the configs, ' 'but len(dims_x["atac"]) == 1. To forcibly encode ATAC data directly, please remove these settings from configs.' ) exit() if batch_names is None: batch_names = ['batch_%d' for i in range(len(datalist))] cls.batch_names = batch_names cls.sampler_type = sampler_type cls.datalist = datalist cls.dims_s = dims_s cls.s_joint = s_joint cls.combs = combs cls.mods = list(dims_x.keys()) # Extract modality names from dims_x keys cls.save_model_path = save_model_path cls.batch_size = batch_size cls.n_save = n_save cls.viz_umap_tb = viz_umap_tb if hasattr(cls, 'net'): logging.info('Loading pre-defined network structure...') net = VAE(cls.dims_x, cls.dims_s, **cls.configs) cls.net = cls.update(net, cls.net) else: logging.info('Defining new network structure...') cls.net = VAE(dims_x, cls.dims_s, **cls.configs) cls.dsc = Discriminator(dims_x, cls.dims_s, **cls.configs) return cls()
[docs] def train_dataloader(self) -> DataLoader: """ Create a DataLoader for training, using the appropriate sampler. Returns: DataLoader : Configured DataLoader instance for training. """ # Concatenate all datasets try: dataset = ConcatDataset(self.datalist) logging.info(f'Total number of samples: {len(dataset)} from {len(self.datalist)} datasets.') except Exception as e: raise ValueError('Failed to concatenate datasets. Please check the input datalist.') from e if self.sampler_type == 'cl': sampler = MultiBatchContinualLearningSampler( dataset, batch_size=self.batch_size, # n_max=self.n_max, n_current_datasets=self.current_batches, n_replay_datasets=self.replay_batches ) else: # Select the appropriate sampler if self.sampler_type == 'ddp': logging.info('Using Distributed Data Parallel (DDP) sampler.') sampler = MyDistributedSampler(dataset, batch_size=self.batch_size, n_max=self.n_max) else: logging.info('Using MultiBatchSampler for data loading.') sampler = MultiBatchSampler(dataset, batch_size=self.batch_size, n_max=self.n_max) # Create the DataLoader try: train_loader = DataLoader( dataset, sampler=sampler, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers ) logging.info(f'DataLoader created with batch size {self.batch_size} and {self.num_workers} workers.') except Exception as e: raise RuntimeError('Failed to create DataLoader. Check DataLoader configuration.') from e logging.debug(f'DataLoader: {len(train_loader)}') return train_loader
[docs] def configure_optimizers(self) -> List[torch.optim.Optimizer]: """ Configure optimizers for the MIDAS model. Returns: List[torch.optim.Optimizer] : List of optimizers for the network and discriminator. """ logging.debug(f'net:{self.net}') logging.debug(f'dsc:{self.dsc}') self.net_optim = getattr(torch.optim, self.optim_net)(self.net.parameters(), lr=self.lr_net) self.dsc_optim = getattr(torch.optim, self.optim_dsc)(self.dsc.parameters(), lr=self.lr_dsc) # 如果你已经加载过 optimizer state dict,在这里 load if self.load_optimizer_state: self.net_optim.load_state_dict(self.loaded_net_optim_state) self.dsc_optim.load_state_dict(self.loaded_dsc_optim_state) return [self.net_optim, self.dsc_optim]
[docs] def training_step(self, batch: Dict[str, Dict[str, torch.Tensor]], batch_idx: int) -> torch.Tensor: """ Executes a single training step for MIDAS. Parameters: batch : Dict[str, Dict[str, torch.Tensor]] Input batch containing modality data, batch indices, and masks. batch_idx : int Index of the current training batch. Returns: torch.Tensor : Total VAE loss for the current batch. """ # Forward pass through the VAE logging.debug(f"Training step - batch index: {batch_idx}") logging.debug(f"Input: {batch}") x_r_pre, s_r_pre, z_mu, z_logvar, z, c, u, z_uni, c_all = self.net(batch) logging.debug(f'Current batch: {batch['s']['joint'][0]}') c_all['joint'] = c # Compute reconstruction loss recon_loss, recon_dict = self.calc_recon_loss( batch['x'], batch['s']['joint'], batch['e'], x_r_pre, s_r_pre, filter_keys(self.__dict__, 'distribution_dec_'), filter_keys(self.__dict__, 'lam_recon_') ) recon_loss *= self.lam_recon # Compute KLD loss kld_loss = self.calc_kld_z_loss( self.dim_c, self.dim_u, self.lam_kld_c, self.lam_kld_u, z_mu, z_logvar ) * self.lam_kld # Compute consistency loss consistency_loss = self.calc_consistency_loss(z_uni) * self.lam_alignment # Compute total VAE loss loss_net = recon_loss + kld_loss + consistency_loss if self.train_mod == 'offline': scale = 1.0 else: scale = self.scale[batch['s']['joint'][0][0]] logging.debug(f'Scale: {scale:.4f}') # Train discriminator for n_iter_disc iterations for _ in range(self.n_iter_disc): self.train_discriminator(c_all, batch['s'], scale) # Compute adversarial loss for the VAE s_pred = self.dsc(c_all) loss_dsc = self.calc_dsc_loss(s_pred, batch['s']) * self.lam_dsc loss_net = loss_net - loss_dsc * self.lam_adv self.update_model(loss_net*scale, self.net, self.net_optim, self.grad_clip) # Log training losses self.log_losses(recon_loss, kld_loss, consistency_loss, loss_net, loss_dsc, recon_dict) return loss_net
[docs] def train_discriminator(self, c_all: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], scale: float = 1.0) -> None: """ Train the discriminator with modality-specific latent representations. Parameters: c_all : Dict[str, torch.Tensor] Dictionary of latent representations for each modality. targets : Dict[str, torch.Tensor] Ground truth batch labels for each modality. """ s_pred = self.dsc(detach_tensors(c_all)) loss_dsc = self.calc_dsc_loss(s_pred, targets) * self.lam_dsc * scale self.update_model(loss_dsc, self.dsc, self.dsc_optim, self.grad_clip)
[docs] def pack_data( self, des_dir, subsample_num: list = None, subsample='random', format='mtx', use_transform: bool = False, return_idx = True ): """ Packs data from datalist into a specified directory structure, with an option for proportional subsampling. """ if os.path.exists(des_dir): logging.warning(f'Directory {des_dir} already exists. Contents may be overwritten.') # Determine if proportional sampling should be performed. all_indices = [] if subsample_num is not None and subsample == 'BTS': if sum(len(self.datalist[i]) for i in range(len(self.datalist))) > sum(subsample_num): z = self.predict()['z']['joint'][:, :self.dim_c] for batch in range(len(self.datalist)): if subsample_num[batch] < len(self.datalist[batch]): start_idx = sum(len(self.datalist[i]) for i in range(batch)) end_idx = start_idx + len(self.datalist[batch]) id = sorted(BallTreeSubsample(z[start_idx:end_idx], subsample_num[batch])) all_indices.append([int(x) for x in id]) else: for batch in range(len(self.datalist)): all_indices.append(list(range(len(self.datalist[batch])))) elif subsample_num is not None and subsample in ['random']: for batch in range(len(self.datalist)): if subsample_num[batch] < len(self.datalist[batch]): all_indices.append(sorted(np.random.permutation(len(self.datalist[batch])[:subsample_num[batch]]))) else: for batch in range(len(self.datalist)): all_indices.append(list(range(len(self.datalist[batch])))) for batch_id, data in enumerate(self.datalist): logging.info('Processing batch %d: %s' % (batch_id, str(self.combs[batch_id]))) current_data = data if subsample_num is not None: if len(data) > 0: if subsample is not None: # Generate random indices and create a Subset. current_data = data.__subset__(all_indices[batch_id]) logging.info(f"Sampling {len(current_data)} of {len(data)} items from batch {batch_id}.") else: logging.warning(f"Batch {batch_id} is empty. Skipping.") if not use_transform: current_data.transform = {} if format in ['mtx', 'csv']: data_loader = DataLoader(current_data, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers) elif format == 'vec': data_loader = DataLoader(current_data, shuffle=False, batch_size=1, num_workers=self.num_workers) if len(current_data) > 0: fname_fmt = get_name_fmt(len(data_loader)) + '.csv' else: raise ValueError(f'Unsupported format: {format}. Supported formats are "mtx", "csv", and "vec".') cum_data = {} last_batch_data = None for i, batch_data in enumerate(tqdm(data_loader)): os.makedirs(os.path.join(des_dir, f'batch_{batch_id}'), exist_ok=True) if format in ['mtx', 'csv']: for m in batch_data['x'].keys(): cum_data.setdefault(m, []).append(batch_data['x'][m]) else: # format == 'vec' for m in batch_data['x'].keys(): os.makedirs(os.path.join(des_dir, f'batch_{batch_id}', 'vec', m), exist_ok=True) save_tensor_to_csv(batch_data['x'][m], os.path.join(des_dir, f'batch_{batch_id}', 'vec', m, fname_fmt % i)) last_batch_data = batch_data if format in ['mtx', 'csv']: if cum_data: cum_data = {m: torch.cat(cum_data[m], dim=0) for m in cum_data.keys()} logging.debug(f'Concatenated data {cum_data}') os.makedirs(os.path.join(des_dir, f'batch_{batch_id}', 'mat'), exist_ok=True) logging.info(f'Saving concatenated data for batch {batch_id}...') for m in cum_data.keys(): if format == 'mtx': save_tensor_to_mtx(cum_data[m], os.path.join(des_dir, f'batch_{batch_id}', 'mat', m + '.mtx')) else: save_tensor_to_csv(cum_data[m], os.path.join(des_dir, f'batch_{batch_id}', 'mat', m + '.csv'), header=True, index=True) # Save mask from the last processed batch, if any data was processed. if last_batch_data: os.makedirs(os.path.join(des_dir, f'batch_{batch_id}', 'mask'), exist_ok=True) for m in last_batch_data['e'].keys(): pd.DataFrame(last_batch_data['e'][m][0].numpy()).T.to_csv(os.path.join(des_dir, f'batch_{batch_id}', 'mask', m + '.csv'), header=True, index=True) else: logging.warning(f'No data was processed for batch {batch_id} (possibly due to sampling). Mask will not be saved.') # Save feature dimensions os.makedirs(os.path.join(des_dir, 'feat'), exist_ok=True) with open(os.path.join(des_dir, 'feat', 'feat_dims.toml'), 'w') as f: toml.dump(self.dims_x, f) if subsample_num is not None: # write json of indices for batch_id in range(len(all_indices)): with open(os.path.join(des_dir, f'batch_{batch_id}','subsample.json'), 'w') as f: json.dump(all_indices[batch_id], f, indent=4) if subsample_num is not None and return_idx: return all_indices
[docs] @rank_zero_only def predict(self, return_format = 'array', save_dir: str = None, save_format = 'h5ad', joint_latent: bool = True, mod_latent: bool = False, impute: bool = False, batch_correct: bool = False, translate: bool = False, input: bool = False, )->Optional[Dict[str, Dict[str, torch.Tensor]]]: """ Run model inference to generate latent embeddings, reconstructed data, or translated modalities. The method iterates through the dataset, computes representations using the trained network, and optionally saves the results to disk or returns them in memory. Parameters ---------- return_format : str, default='array' The format of the returned data. - 'array': Returns a nested dictionary of Numpy arrays. - 'anndata': Returns a dictionary of AnnData objects (requires `scanpy`). save_dir : str, optional, default=None Directory path to save the prediction results. If None, results are not saved to disk. save_format : str, default='h5ad' File format for saving results if `save_dir` is provided. - 'h5ad': Saves as H5AD files (one per batch). - 'mtx': Saves as Matrix Market files inside batch-specific folders. joint_latent : bool, default=True If True, computes the joint latent representation (z) conditioned on all observed modalities. mod_latent : bool, default=False If True, computes latent representations conditioned on each individual modality. (Automatically set to True if `translate` is True). impute : bool, default=False If True, generates imputed data (x_impt) from the joint latent space. batch_correct : bool, default=False If True, calculates batch centroids and performs batch-effect correction on the reconstructed data (x_bc). Note: This involves a second pass through the data. translate : bool, default=False If True, performs cross-modality translation (e.g., generating Modality B from Modality A). input : bool, default=False If True, includes the original input raw data and masks in the output. Returns ------- output : Dict A dictionary where keys are batch names. 1. If return_format='array': { 'batch_name': { 'z_c': {'joint': np.ndarray, 'modality_name': ...}, # Content latent 'z_u': {'joint': np.ndarray, ...}, # Batch latent 'x_impt': {'modality_name': np.ndarray}, # Imputed data 'x_bc': {'modality_name': np.ndarray}, # Batch corrected data ... }, ... } 2. If return_format='anndata': { 'batch_name': ann_data_object } (Latent variables are stored in `adata.obsm`, masks in `adata.uns`). """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logging.info(f'Predicting using device: {device}') model = self.net.to(device) model.eval() if translate: mod_latent = True logging.info('Predicting ...') pred = {} with torch.no_grad(): for batch_id, data in enumerate(self.datalist): data_loader = DataLoader(data, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers) logging.info('Processing batch %s: %s' % (self.batch_names[batch_id], str(self.combs[batch_id]))) for i, data in enumerate(tqdm(data_loader)): data = convert_tensors_to_cuda(data, device) for k in data['s'].keys(): safe_append(pred, self.batch_names[batch_id], ['s', k], data['s'][k]) # conditioned on all observed modalities # if joint_latent or batch_correct or impute: x_r_pre, _, _, _, z, _, _, *_ = model(data) # N * K safe_append(pred, self.batch_names[batch_id], ['z', 'joint'], z) if impute: x_r = model.gen_real_data(x_r_pre, sampling=False) for m in x_r.keys(): safe_append(pred, self.batch_names[batch_id], ['x_impt', m], x_r[m]) if input: # save the input for m in self.combs[batch_id]: safe_append(pred, self.batch_names[batch_id], ['x', m],(data['x'][m])) if m in data['e']: safe_append(pred, self.batch_names[batch_id], ['mask', m],(data['e'][m])) # conditioned on each individual modalities if mod_latent: for m in data['x'].keys(): input_data = { 'x': {m: data['x'][m]}, 's': data['s'], 'e': {} } if m in data['e'].keys(): input_data['e'][m] = data['e'][m] x_r_pre, _, _, _, z, c, u, *_ = model(input_data) # N * K safe_append(pred, self.batch_names[batch_id], ['z', m], z) if translate: # from a to b all_combinations = generate_all_combinations(self.mods) for input_mods, output_mods in all_combinations: flag = True input_mods_sorted = sorted(input_mods) for m in input_mods_sorted: if m not in data['x'].keys(): flag = False if flag: input_mods_sorted = sorted(input_mods) input_data = { 'x': {m: data['x'][m] for m in input_mods_sorted if m in data['x']}, 's': data['s'], 'e': {} } for m in input_mods_sorted: if m in data['e'].keys(): input_data['e'][m] = data['e'][m] x_r_pre, *_ = model(input_data) # N * K x_r = model.gen_real_data(x_r_pre, sampling=False) for mod in output_mods: safe_append(pred, self.batch_names[batch_id], ['x_trans', '_'.join(input_mods_sorted) + '_to_' + mod], x_r[mod]) if batch_correct: logging.info('Calculating u_centroid ...') u = torch.concat([torch.concat(pred[i]['z']['joint']) for i in self.batch_names])[:, self.dim_c:] s = torch.concat([torch.concat(pred[i]['s']['joint']) for i in self.batch_names]).flatten() u_mean = u.mean(dim=0, keepdim=True) u_batch_mean_list = [] for batch_id in s.unique(): u_batch = u[s == batch_id, :] u_batch_mean_list.append(u_batch.mean(dim=0)) u_batch_mean_stack = torch.stack(u_batch_mean_list, dim=0) dist = ((u_batch_mean_stack - u_mean) ** 2).sum(dim=1) model.u_centroid = u_batch_mean_list[dist.argmin()] model.batch_correction = True logging.info('Batch correction ...') for batch_id, data in enumerate(self.datalist): data_loader = DataLoader(data, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers) logging.info('Processing batch %s: %s' % (self.batch_names[batch_id], str(self.combs[batch_id]))) for i, data in enumerate(tqdm(data_loader)): data = convert_tensors_to_cuda(data, device) x_r_pre, *_ = model(data) x_r = model.gen_real_data(x_r_pre, sampling=True) for m in self.mods: safe_append(pred, self.batch_names[batch_id], ['x_bc', m], x_r[m]) # concatenate pred_b = {} for batch, data in pred.items(): pred_b[batch] = {} for var, data_v in data.items(): if var == 'z': pred_b[batch]['z_c'] = {} pred_b[batch]['z_u'] = {} else: pred_b[batch][var] = {} for m, data_m in data_v.items(): if var == 'z': pred_b[batch]['z_c'][m] = torch.cat(data_m).cpu().numpy()[:, :self.dim_c] pred_b[batch]['z_u'][m] = torch.cat(data_m).cpu().numpy()[:, self.dim_c:] elif var == 'mask': pred_b[batch][var][m] = data_m[0][0].cpu().numpy() else: pred_b[batch][var][m] = torch.cat(data_m).cpu().numpy() if save_dir is not None and save_format=='mtx': for batch, data in pred_b.items(): os.makedirs(os.path.join(save_dir, self.batch_names[batch]),exist_ok=True) for var, data_v in data.items(): for m, data_m in data_v.items(): mmwrite(os.path.join(save_dir, self.batch_names[batch], '%s_%s.mtx'%(var, m)), csr_matrix(pred_b[batch][var][m])) if (save_dir is not None and save_format=='h5ad') or return_format == 'anndata': # if group_by == 'batch': adata_all = {} for batch, data in pred_b.items(): adata = sc.AnnData(np.zeros([len(data['z_c']['joint']), 0])) adata.obs['batch'] = batch for var, data_v in data.items(): if var == 's': continue for m, data_m in data_v.items(): if m=='joint' and (not joint_latent): pass elif var=='mask': adata.uns['mask_%s'%m] = data_m else: if data_m.shape[1] > 10000: data_m = csr_matrix(data_m) adata.obsm['%s_%s' % (var, m)] = data_m adata_all[batch] = adata if (save_dir is not None and save_format=='h5ad'): for batch, adata in adata_all.items(): os.makedirs(save_dir, exist_ok=True) sc.write(os.path.join(save_dir, '%s.h5ad'%batch), adata) if return_format == 'anndata': return adata_all if not joint_latent: for batch, data in pred_b.items(): data.pop('z_c_joint', None) data.pop('z_u_joint', None) return pred_b
[docs] def on_train_epoch_end(self): """ Save a model checkpoint at the end of each training epoch with a meaningful filename. """ # Save the checkpoint periodically based on n_save if (self.current_epoch+self.start_epoch)!=0 and (self.current_epoch+1+self.start_epoch) % self.n_save == 0: os.makedirs(self.save_model_path, exist_ok=True) # Get the current timestamp timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') # Generate a descriptive checkpoint filename checkpoint_filename = f'model_epoch{self.current_epoch+1+self.start_epoch}_{timestamp}.pt' checkpoint_path = os.path.join(self.save_model_path, checkpoint_filename) # Save the checkpoint self.save_checkpoint(checkpoint_path) if self.viz_umap_tb: self.get_emb_umap()
# shutil.rmtree(self.save_model_path+'/predict'+timestamp)
[docs] def on_train_end(self): """ Save the final model checkpoint at the end of training. """ os.makedirs(self.save_model_path, exist_ok=True) timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') checkpoint_filename = f'model_epoch{self.current_epoch+self.start_epoch}_{timestamp}.pt' checkpoint_path = os.path.join(self.save_model_path, checkpoint_filename) self.save_checkpoint(checkpoint_path)
[docs] @rank_zero_only def save_checkpoint(self, checkpoint_path: str): """ Save the current model and optimizer states to a checkpoint file. Parameters: checkpoint_path : str Path to save the checkpoint file. Raises: ValueError: If `checkpoint_path` is an invalid or empty string. """ # Validate the output path if not checkpoint_path or not isinstance(checkpoint_path, str): raise ValueError('Invalid checkpoint path. Please provide a valid string.') # Create a state dictionary with model and optimizer states checkpoint_data = { 'net': self.net.state_dict(), # State dictionary of the main model 'dsc': self.dsc.state_dict(), # State dictionary of the discriminator 'optim_net': self.net_optim.state_dict(), # State dictionary of the main optimizer 'optim_dsc': self.dsc_optim.state_dict() # State dictionary of the discriminator optimizer } # Save the state dictionary to the specified path torch.save(checkpoint_data, checkpoint_path) # Inform the user of successful save logging.info(f'Checkpoint successfully saved to "{checkpoint_path}".')
[docs] def load_checkpoint(self, checkpoint_path: str, start_epoch: int = 0, **kwargs): """ Load model and optimizer states from a checkpoint file. Parameters: checkpoint_path : str Path to the checkpoint file containing saved model and optimizer states. start_epoch: int Indicate how many epoch the model has been trained. kwargs : Dict[str, Any] Additional configurations for torch.load(). Raises: AssertionError: If the provided checkpoint path does not exist. """ # Verify the checkpoint path exists assert os.path.exists(checkpoint_path), f'Checkpoint path "{checkpoint_path}" does not exist.' # Load the checkpoint file checkpoint_data = torch.load(checkpoint_path, weights_only=True, **kwargs) # Load the model state dictionaries self.net.load_state_dict(checkpoint_data['net']) self.dsc.load_state_dict(checkpoint_data['dsc']) # Load the optimizer state dictionaries self.load_optimizer_state = True self.loaded_net_optim_state = checkpoint_data['optim_net'] self.loaded_dsc_optim_state = checkpoint_data['optim_dsc'] self.start_epoch = start_epoch # influence saving name of checkpoints
[docs] @rank_zero_only def get_emb_umap( self, pred_dir: str=None, save_dir: str=None, use_mtx: bool=True, drop_c_umap: bool=False, drop_u_umap: bool=False, **kwargs) -> Tuple[List[sc.AnnData], List[plt.Figure]]: """ Generate UMAP embeddings for biological (c) and technical (u) latent variables. Parameters: pred_dir : str Directory containing predicted data. save_dir : str, optional Directory to save UMAP plots, by default './'. use_mtx : bool, optional Whether to load embeddings of mtx format, by default True. drop_c_umap : bool, optional Whether to drop the biological embedding (c) from UMAP, by default False. drop_u_umap : bool, optional Whether to drop the technical embedding (u) from UMAP, by default False. kwargs : Dict[str, Any] Additional configurations for sc.pl.umap(). Returns: Tuple : all_adata : List[AnnData] List of AnnData objects containing UMAP embeddings for each modality. all_figures : List[plt.Figure] List of UMAP figures for biological and technical embeddings. """ logging.info(f'Loading predicted data from: {pred_dir}') if pred_dir is not None: pred = load_predicted(pred_dir, self.combs, mtx=use_mtx) else: pred = self.predict() # Extract biological and technical embeddings and batch labels bio_embedding = pred['z']['joint'][:, :self.dim_c] # Biological embedding tech_embedding = pred['z']['joint'][:, self.dim_c:] # Technical embedding batch_labels = pred['s']['joint'].astype('int').astype('str') # Batch labels all_adata = [] # List to store AnnData objects all_figures = [] # List to store UMAP figures file_names = [] # List to store file names for UMAP plots file_names = ['biological_information.png', 'technical_information.png'] # File names for UMAP plotconds # Generate UMAP for both embeddings for index, (embedding, file_name) in enumerate(zip([bio_embedding, tech_embedding], file_names)): if file_name == 'biological_information.png' and drop_c_umap: logging.info('Skipping biological embedding UMAP generation as drop_c_umap is True.') continue if file_name == 'technical_information.png' and drop_u_umap: logging.info('Skipping technical embedding UMAP generation as drop_u_umap is True.') continue logging.info(f"Processing {'biological' if index == 0 else 'technical'} embedding...") # Create AnnData object for the embedding adata = sc.AnnData(embedding) adata.obs['batch'] = batch_labels # Compute nearest neighbors and UMAP logging.info(' - Computing neighbors...') sc.pp.neighbors(adata) logging.info(' - Computing UMAP...') sc.tl.umap(adata) # Plot UMAP and optionally save the figure logging.info(f' - Generating UMAP plot for {file_name}...') fig = sc.pl.umap(adata, title=file_name[:-4], color='batch', show=False, return_fig=True, **kwargs) all_figures.append(fig) if save_dir: fig_save_path = os.path.join(save_dir, 'figs', file_name) os.makedirs(os.path.dirname(fig_save_path), exist_ok=True) fig.savefig(fig_save_path) logging.info(f' - UMAP plot saved to: {fig_save_path}') if self.logger and self.viz_umap_tb: self.logger.experiment.add_figure(file_name, fig, self.current_epoch+1+self.start_epoch) all_adata.append(adata) logging.info('UMAP generation completed.') return all_adata, all_figures
[docs] def log_losses(self, recon_loss: torch.Tensor, kld_loss, consistency_loss: torch.Tensor, loss_net: torch.Tensor, loss_dsc: torch.Tensor, recon_dict: Dict[str, torch.Tensor]): """ Log losses for monitoring and debugging during training. Parameters: recon_loss : torch.Tensor Reconstruction loss. kld_loss : torch.Tensor KLD loss. consistency_loss : torch.Tensor Consistency loss. recon_dict : Dict[str, torch.Tensor] Per-modality reconstruction losses. loss_net : torch.Tensor Total VAE loss. loss_dsc : torch.Tensor Discriminator loss. """ self.log_dict( { 'loss_/recon_loss': recon_loss, 'loss_/kld_loss': kld_loss, 'loss_/consistency_loss': consistency_loss, 'loss/net': loss_net, 'loss/dsc':loss_dsc }, prog_bar=True, on_epoch=True, sync_dist=True, ) self.log_dict(recon_dict, on_epoch=True, sync_dist=True)
[docs] @staticmethod def update_model( loss: torch.Tensor, model: torch.nn.Module, optimizer: torch.optim.Optimizer, grad_clip: int=-1): """ Update model parameters using backpropagation. Parameters: loss : torch.Tensor Computed loss for backpropagation. model : torch.nn.Module Model to update. optimizer : torch.optim.Optimizer Optimizer for parameter updates. grad_clip : int True to allow clipping gradient. """ optimizer.zero_grad() loss.backward() if grad_clip > 0: nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step()
[docs] @staticmethod def calc_dsc_loss(pred: Dict[str, torch.Tensor], true: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculate the discriminator loss using cross-entropy. Parameters: pred : Dict[str, torch.Tensor] Predicted logits for each modality. true : Dict[str, torch.Tensor] Ground truth labels for each modality. Returns: torch.Tensor : Computed discriminator loss. """ cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum') # Cross-entropy loss loss = {} # Compute loss for each modality for modality in pred: loss[modality] = cross_entropy_loss(pred[modality], true[modality].squeeze(1)) # Normalize total loss by batch size total_loss = sum(loss.values()) / pred['joint'].size(0) return total_loss
[docs] @staticmethod def calc_kld_z_loss(dim_c: int, dim_u: int, lam_kld_c: float, lam_kld_u: float, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: """ Calculate the Kullback-Leibler Divergence (KLD) loss for latent variables z. Parameters: dim_c : int Dimension of the biological latent space. dim_u : int Dimension of the technical latent space. lam_kld_c : float Weight for KLD loss of the biological latent space. lam_kld_u : float Weight for KLD loss of the technical latent space. mu : torch.Tensor Mean of the latent variable distribution (batch_size x (dim_c + dim_u)). logvar : torch.Tensor Log-variance of the latent variable distribution (batch_size x (dim_c + dim_u)). Returns: torch.Tensor: Weighted sum of KLD losses for the biological and technical latent spaces. """ # Split the mean and log-variance into biological (c) and technical (u) components mu_c, mu_u = mu.split([dim_c, dim_u], dim=1) logvar_c, logvar_u = logvar.split([dim_c, dim_u], dim=1) # Calculate KLD losses for biological and technical latent spaces kld_c_loss = MIDAS.calc_kld_loss(mu_c, logvar_c) kld_u_loss = MIDAS.calc_kld_loss(mu_u, logvar_u) # Combine the losses with their respective weights kld_z_loss = kld_c_loss * lam_kld_c + kld_u_loss * lam_kld_u return kld_z_loss
[docs] @staticmethod def calc_kld_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: """ Calculate the KLD loss for a single latent space. Parameters: mu : torch.Tensor Mean of the latent variable distribution (batch_size x latent_dim). logvar : torch.Tensor Log-variance of the latent variable distribution (batch_size x latent_dim). Returns: torch.Tensor : KLD loss for the latent space, normalized by batch size. """ # KLD loss formula: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0) return kld_loss
[docs] @staticmethod def calc_consistency_loss(z_uni: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculate the consistency loss for unified latent variables across modalities. Parameters: z_uni : Dict[str, torch.Tensor] Dictionary of unified latent variables for each modality, where each value is a tensor of shape (batch_size x latent_dim). Returns: torch.Tensor : Consistency loss computed as the variance of the unified latent variables. """ # Stack the unified latent variables along a new dimension (modalities) z_uni_stack = torch.stack(list(z_uni.values()), dim=0) # Shape: M x N x K (M=modalities, N=batch_size, K=latent_dim) # Calculate the mean across modalities z_uni_mean = z_uni_stack.mean(0, keepdim=True) # Shape: 1 x N x K # Consistency loss is the variance across modalities consistency_loss = ((z_uni_stack - z_uni_mean) ** 2).sum() / z_uni_stack.size(1) # Normalize by batch size return consistency_loss
[docs] @staticmethod def calc_recon_loss( x: Dict[str, torch.Tensor], s: torch.Tensor, e: Dict[str, torch.Tensor], x_r_pre: Dict[str, torch.Tensor], s_r_pre: Dict[str, torch.Tensor], dist: Dict[str, str], lam: Dict[str, float] ) -> Tuple[float, Dict[torch.Tensor, torch.Tensor]]: """ Calculate the reconstruction loss for input data and predicted outputs. Parameters: x : Dict[str, torch.Tensor] Original input data for each modality (x^m). s : torch.Tensor Ground truth batch labels. e : Dict[str, torch.Tensor] Mask. x_r_pre : Dict[str, torch.Tensor] Reconstructed predictions for each modality (x_r^m). s_r_pre : Dict[str, torch.Tensor] Reconstructed predictions for batch labels. dist : Dict[str, str] Dictionary specifying the distribution type for each modality's decoder. lam : Dict[str, float] Dictionary containing reconstruction loss weights for each modality and for s. Returns: Tuple: - total_loss : torch.Tensor Total reconstruction loss, normalized by batch size. - losses : Dict[str, torch.Tensor] Dictionary containing reconstruction losses for each modality and for batch labels. """ losses = {} # Compute reconstruction loss for each modality for modality, x_original in x.items(): # Get the appropriate loss function based on the modality's decoder distribution loss_fn = distribution_registry.get_loss(dist[f'distribution_dec_{modality}']) # Check if there is an event-specific mask for the modality if modality in e: # Apply event-specific mask to the reconstruction loss logging.debug(f'recon {modality} {x_r_pre[modality]} {x_original}') losses[f'recon_loss/{modality}'] = ( loss_fn(x_r_pre[modality], x_original) * e[modality] ).sum() * lam[f'lam_recon_{modality}'] else: # Compute the reconstruction loss without a mask losses[f'recon_loss/{modality}'] = ( loss_fn(x_r_pre[modality], x_original) ).sum() * lam[f'lam_recon_{modality}'] # Compute reconstruction loss for batch labels, if provided if s_r_pre is not None: # Use cross-entropy loss for batch label reconstruction losses['recon_loss/s'] = ( distribution_registry.get_loss('CE')(s_r_pre, s.squeeze(1)) ).sum() * lam['lam_recon_s'] # Normalize total loss by the batch size total_loss = sum(losses.values()) / s.size(0) return total_loss, losses
[docs] @staticmethod def get_info_from_dir(dir_path: str, format: str): """ Extract data, mask, and feature dimensions from a directory of vectors. Parameters: dir_path : str Path to the directory containing data and mask files. format : str Support 'mtx', 'csv', and 'vec'. Returns: Tuple: - data : List[Dict[str, str]] List of dictionaries where keys are modalities and values are file paths. - mask : List[Dict[str, str]] List of dictionaries where keys are modalities and values are mask file paths. - dims_x : Dict[str, list] Dictionary containing feature dimensions for each modality. Notes: The directory should be organized as:: dataset/ feat/ # Dimensions of each modality: {mod1=[...], mod2=[...]}. # Split the data into chunks if the length of the list is greater than 1. # For instance, you can split the ATAC data by chromosomes. feat_dims.toml batch_0/ mask/mod1.csv mask/mod2.csv vec/mod1/0000.csv # the first sample vec/mod1/0001.csv # the second sample .... vec/mod2/0000.csv vec/mod2/0001.csv .... batch_1/ mask/mod1.csv mask/mod2.csv vec/mod1/0000.csv vec/mod1/0001.csv .... vec/mod2/0000.csv vec/mod2/0001.csv .... or like:: dataset/ feat/ # Dimensions of each modality: {mod1=[...], mod2=[...]}. # Split the data into chunks if the length of the list is greater than 1. # For instance, you can split the ATAC data by chromosomes. feat_dims.toml batch_0/ mask/mod1.csv mask/mod2.csv mat/mod1.mtx (.csv) mat/mod2.mtx (.csv) .... batch_1/ mask/mod1.csv mask/mod2.csv mat/mod1.mtx (.csv) mat/mod2.mtx (.csv) .... """ data = [] # List to store data file paths mask = [] # List to store mask file paths batch_names = [] for batch_dir in natsort.natsorted(os.listdir(dir_path)): if batch_dir != 'feat': # Ignore the 'feat' directory data_batch = {} mask_batch = {} batch_path = os.path.join(dir_path, batch_dir) batch_names.append(batch_dir) # Collect file paths for data and masks if format == 'vec': if os.path.exists(batch_path): vec_dir = os.path.join(batch_path, 'vec') mask_dir = os.path.join(batch_path, 'mask') for file in os.listdir(vec_dir): data_batch[file] = os.path.join(vec_dir, file) for file in os.listdir(mask_dir): mask_batch[file[:-4]] = os.path.join(mask_dir, file) elif format in ['csv', 'mtx']: if os.path.exists(batch_path): mat_dir = os.path.join(batch_path, 'mat') mask_dir = os.path.join(batch_path, 'mask') for file in os.listdir(mat_dir): data_batch[file[:-4]] = os.path.join(mat_dir, file) for file in os.listdir(mask_dir): mask_batch[file[:-4]] = os.path.join(mask_dir, file) data.append(data_batch) mask.append(mask_batch) # Load feature dimensions from 'feat_dims.toml' dims_x = toml.load(os.path.join(dir_path, 'feat', 'feat_dims.toml')) return data, mask, dims_x, batch_names
[docs] @classmethod def configure_new_data_from_dir(cls, configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, scale = None, viz_umap_tb: bool = False, **kwargs : Dict[str, Any]) -> 'MIDAS': # Extract data, mask, and feature dimensions from the directory data, mask, dims_x, batch_names = cls.get_info_from_dir(dir_path, format) # combine dims_x keys or keep the same cls.batch_names = batch_names if not hasattr(cls, 'batch_names') else cls.batch_names + batch_names cls.dims_x = dims_x if not hasattr(cls, 'dims_x') else {**cls.dims_x, **dims_x} logging.debug(f'Configured dims_x: {cls.dims_x}') # Configure datasets and associated parameters datalist, dims_s, s_joint, combs = cls.get_datasets_from_dir(data, mask, transform, format) cls.replay_batches = len(cls.datalist) - len(data) cls.current_batches = len(data) cls.scale = scale # merge transform transform = {**kwargs.get('transform',{}), **(transform if transform else {})} logging.debug(f'dims_s:{dims_s}, transform:{transform}, s_joint:{s_joint}, combs:{combs}') cls.start_epoch = 0 cls.load_optimizer_state = False # Finalize and return class instance return cls.configure_data( configs, datalist, cls.dims_x, dims_s, s_joint, combs, sampler_type='cl', viz_umap_tb=viz_umap_tb, batch_names=cls.batch_names, **kwargs)
[docs] @classmethod def configure_data_from_dir(cls, configs: Dict[str, Any], dir_path: str, format: str = 'mtx', transform: Dict[str, str] = None, sampler_type: str = 'auto', viz_umap_tb: bool = False, **kwargs : Dict[str, Any]) -> 'MIDAS': """ Configure data from a directory and apply optional transformations. Parameters: configs : Dict[str, Any] Configurations of the model. dir_path : str Path to the directory containing data files. transform : Dict[str, str], optional A dictionary specifying transformations to apply to specific modalities. Example: {'atac': 'binarize'} Default is None, which uses the default transformation settings. sampler_type : str, optional Type of sampler to use, by default 'auto'. For 'ddp', use distributed sampler. kwargs : Dict[str, Any] Additional parameters passed to configure_data(). Returns: class 'MIDAS': Returns the configured class instance. Examples: >>> from scmidas.model import MIDAS >>> from scmidas.config import load_config >>> configs = load_config() >>> dir_path = 'XXX' >>> transform = {'atac': 'binarize'} >>> model = MIDAS.configure_data_from_dir(configs, dir_path, transform) """ # Extract data, mask, and feature dimensions from the directory data, mask, dims_x, batch_names = cls.get_info_from_dir(dir_path, format) cls.dims_x = dims_x cls.batch_names = batch_names # Configure datasets and associated parameters datalist, dims_s, s_joint, combs = cls.get_datasets_from_dir(data, mask, transform, format) cls.start_epoch = 0 cls.load_optimizer_state = False # Finalize and return class instance return cls.configure_data( configs, datalist, dims_x, dims_s, s_joint, combs, sampler_type=sampler_type, viz_umap_tb=viz_umap_tb, batch_names=batch_names, **kwargs)
[docs] @classmethod def get_datasets_from_dir( cls, data: List[Dict[str, str]], mask: List[Dict[str, str]], transform: Dict[str, str]=None, format: str='mtx', ): """ Configure data from a CSV input. Parameters: data : List[Dict[str, str]] List of data dictionaries, where keys are modalities and values are file paths. mask : List[Dict[str, str]] List of mask dictionaries, where keys are modalities and values are mask file paths. transform : Optional[Dict[str, str]] Transformations to apply to specific modalities. format : str File type of the input data, default is 'vec'. ['vec', 'mtx', 'csv'] Returns: Tuple: - datasets : List[MultiModalDataset] List of initialized `MultiModalDataset` objects. - dims_s : Dict[str, int] Dimensions for batch correction for each modality. - s_joint : List[Dict[str, int]] Modality indices for each batch. - combs : List[List[str]] List of modality combinations for each batch. """ cls.s_joint = cls.s_joint if hasattr(cls, 's_joint') else [] # Modality indices for each batch cls.n_s = cls.n_s if hasattr(cls, 'n_s') else {'joint':-1} # Counter for each modality cls.combs = cls.combs if hasattr(cls, 'combs') else [] # Modality combinations for each batch if hasattr(cls, 'datasets'): cls.train_mod = 'continual' else: cls.train_mod = 'offline' cls.datasets = cls.datasets if hasattr(cls, 'datasets') else [] # List of datasets cls.dims_s = cls.dims_s if hasattr(cls, 'dims_s') else {} # Dimensions for batch correction for i, batch_data in enumerate(data): batch_s = {} # Store batch-specific indices batch_combs = [] # Modality combination for the current batch # Assign batch index for each modality for modality in batch_data.keys(): if modality in cls.n_s: batch_s[modality] = cls.n_s[modality] + 1 cls.n_s[modality] += 1 else: batch_s[modality] = 0 cls.n_s[modality] = 0 batch_combs.append(modality) # Add joint batch information cls.n_s['joint'] += 1 batch_s['joint'] = cls.n_s['joint'] cls.s_joint.append(batch_s) cls.combs.append(batch_combs) # Determine file types for each modality file_types = { modality: format for modality in batch_data.keys() } # Initialize MultiModalDataset dataset = MultiModalDataset(batch_data, batch_s, file_types, mask[i], transform) cls.datasets.append(dataset) # Define dimensions for batch correction cls.dims_s = {modality: count + 1 for modality, count in cls.n_s.items()} # concat mask cls.mask = cls.mask + mask if hasattr(cls, 'mask') else mask if cls.train_mod == 'continual': logging.debug(f'expected_dims{cls.dims_x}') for i, dataset in enumerate(cls.datasets[:-len(data)]): dataset_update = MultiModalDataset( dataset.mod_dict, dataset.mod_id_dict, dataset.file_type, dataset.mask_path, dataset.transform, expected_dims = cls.dims_x) cls.datasets[i] = dataset_update logging.debug(f'Update dataset for batch {i} with modalities') MIDAS.print_info(cls.mask, cls.datasets, cls.batch_names) return cls.datasets, cls.dims_s, cls.s_joint, cls.combs
[docs] @staticmethod @rank_zero_only def print_info(mask: List[Dict[str, str]], datalist: List[Dataset], batch_names: List[str]): """ Print summary of mask density and dataset information. Parameters: mask : List[Dict[str, str]] List of mask. datalist : List[Dataset] List of datasets. batch_name : List[str] List of batch names. """ # Calculate mask density for each batch feature = [] valid_feature = [] for i, dataset in enumerate(datalist): s1 = {} s2 = {} dataset = dataset[0] mask_ = mask[i] for m in dataset['x']: s1['#%s'%m.upper()] = len(dataset['x'][m]) if m in mask_: t = pd.read_csv(mask_[m], index_col=0).values s2['#VALID_'+m.upper()] = t.sum() feature.append(s1) valid_feature.append(s2) valid_feature = pd.DataFrame(valid_feature) valid_feature.index = batch_names cell_number = pd.DataFrame({'#CELL':[len(dataset) for dataset in datalist]}) cell_number.index = batch_names feature = pd.DataFrame(feature) feature.index = batch_names data = pd.concat([cell_number, feature, valid_feature], axis=1) # Print summary logging.info('Input data: \n' + data.to_string())
[docs] @classmethod def reset(cls): """ Reset class-level attributes to their initial states. """ attributes_to_reset = [ 'dims_x', 'dims_s', 's_joint', 'n_s', 'combs', 'batch_names', 'datasets', 'mask', 'train_mod', 'replay_batches', 'current_batches', 'datalist', 'scale' ] for attr in attributes_to_reset: if hasattr(cls, attr): delattr(cls, attr)