import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
[docs]
def plot_umap(
adata,
key='z_c_joint',
do_pca=False,
n_comps=32,
color='batch',
**kwargs
):
"""
Computes and plots a UMAP for a single AnnData object based on a specific latent representation.
This function allows for optional PCA preprocessing on the selected latent representation
(stored in .obsm) before computing the neighborhood graph and UMAP.
Args:
adata (AnnData): The input annotated data matrix.
key (str, optional): The key in `adata.obsm` to use as the representation.
Defaults to 'z_c_joint'.
do_pca (bool, optional): Whether to perform scaling and PCA on the representation
before neighbor calculation. Defaults to False.
n_comps (int, optional): The number of principal components to use if `do_pca` is True.
Defaults to 32.
color (str, optional): The key in `adata.obs` used to color the plot.
Defaults to 'batch'.
**kwargs: Additional keyword arguments passed to `sc.pl.umap`.
Returns:
None: Displays the plot.
"""
if do_pca:
adata2 = sc.AnnData(adata.obsm[key])
adata2.obs = adata.obs
sc.pp.scale(adata2)
sc.pp.pca(adata2, n_comps=n_comps)
key = 'X_pca'
else:
adata2 = copy.deepcopy(adata)
sc.pp.neighbors(adata2, use_rep=key)
sc.tl.umap(adata2)
sc.pl.umap(adata2, color=color, **kwargs)
[docs]
def plot_umap_grid(adata, axis1, axis2, color, figsize=2, point_size=2, fontsize=10, background=True):
"""
Plots a grid (facet plot) of UMAPs split by two categorical variables.
This visualizes how specific groups (defined by axis1 and axis2) are distributed
within the global UMAP space.
Args:
adata (AnnData): Annotated data matrix with pre-computed UMAP coordinates (`X_umap`).
axis1 (str): Key in `adata.obs` defining the rows of the grid.
axis2 (str): Key in `adata.obs` defining the columns of the grid.
color (str): Key in `adata.obs` used for coloring the points.
figsize (float, optional): The size (in inches) of each subplot. Defaults to 2.
point_size (float, optional): The size of the scatter points. Defaults to 2.
fontsize (int, optional): Font size for the legend. Defaults to 10.
background (bool, optional): If True, plots all cells in grey in the background
of each subplot to show the global structure. Defaults to True.
Returns:
None: Displays the plot.
"""
axis1_names = adata.obs[axis1].unique()
axis2_names = adata.obs[axis2].unique()
nrows = len(axis1_names)
ncols = len(axis2_names)
fig, ax = plt.subplots(nrows, ncols, figsize=[figsize * ncols, figsize * nrows])
fig_dummy, ax_dummy = plt.subplots()
sc.pl.umap(adata, color=color, show=False, ax=ax_dummy)
handles, labels_ = ax_dummy.get_legend_handles_labels()
plt.close(fig_dummy)
for i, k1 in enumerate(axis1_names):
for j, k2 in enumerate(axis2_names):
if background:
sc.pl.umap(adata, show=False, ax=ax[i, j], s=point_size) # background
sc.pl.umap(adata[(adata.obs[axis1]==k1) & (adata.obs[axis2]==k2)], color=color, show=False, ax=ax[i, j], s=point_size)
ax[i, j].get_legend().set_visible(False)
ax[i, j].set_xticks([])
ax[i, j].set_yticks([])
ax[i, j].set_xlabel('')
if j==0:
ax[i, j].set_ylabel(k1)
else:
ax[i, j].set_ylabel('')
if i==0:
ax[i, j].set_title(k2)
else:
ax[i, j].set_title('')
# create global legend
fig.legend(handles, labels_, loc='center',
bbox_to_anchor=(0.5, -0.02), ncol=len(labels_), fontsize=fontsize)
# adjust the figure
plt.tight_layout(rect=[0.1, 0.05, 1, 1])
plt.show()
[docs]
def plot_z_umap_grid(adata_list, batch_col='batch', color='label', figsize=2, point_size=2, fontsize=10, transpose=False):
"""
Aggregates latent representations from a dictionary of AnnData objects, computes a joint UMAP,
and plots a grid view.
It specifically looks for keys in `.obsm` starting with 'z_c', concatenates them,
and re-computes the UMAP to visualize the alignment or distribution across different batches/types.
Args:
adata_list (dict): A dictionary where keys are batch identifiers and values are AnnData objects.
batch_col (str, optional): Key in `adata.obs` identifying the batch/sample. Defaults to 'batch'.
color (str, optional): Key in `adata.obs` used for coloring. Defaults to 'label'.
figsize (float, optional): The size (in inches) of each subplot. Defaults to 2.
point_size (float, optional): The size of the scatter points. Defaults to 2.
fontsize (int, optional): Font size for the legend. Defaults to 10.
transpose (bool, optional): If True, swaps the row and column axes of the grid
(Batch vs. Type). Defaults to False.
Returns:
None: Displays the plot.
"""
data = []
axis1_ = []
axis2_ = []
label_ = []
for b, adata in adata_list.items():
for k in adata.obsm:
if k.startswith('z_c'):
data.append(adata.obsm[k])
axis1_.append(adata.obs[batch_col])
axis2_.append([k.split('_')[-1].upper() for i in range(len(adata))])
label_.append(adata.obs[color])
data = np.concatenate(data)
axis1_ = np.concatenate(axis1_)
axis2_ = np.concatenate(axis2_)
label_ = np.concatenate(label_)
adata = sc.AnnData(data)
adata.obs['batch'] = axis1_
adata.obs['type'] = axis2_
adata.obs[color] = label_
sc.pp.neighbors(adata)
sc.tl.umap(adata)
axis1 = 'batch' if not transpose else 'type'
axis2 = 'type' if not transpose else 'batch'
axis1_names = adata.obs[axis1].unique()
axis2_names = adata.obs[axis2].unique()
nrows = len(axis1_names)
ncols = len(axis2_names)
fig, ax = plt.subplots(nrows, ncols, figsize=[figsize * ncols, figsize * nrows])
fig_dummy, ax_dummy = plt.subplots()
sc.pl.umap(adata, color=color, show=False, ax=ax_dummy)
handles, labels_ = ax_dummy.get_legend_handles_labels()
plt.close(fig_dummy)
for i, k1 in enumerate(axis1_names):
for j, k2 in enumerate(axis2_names):
sc.pl.umap(adata, show=False, ax=ax[i, j], s=point_size) # background
sc.pl.umap(adata[(adata.obs[axis1]==k1) & (adata.obs[axis2]==k2)], color=color, show=False, ax=ax[i, j], s=point_size)
ax[i, j].get_legend().set_visible(False)
ax[i, j].set_xticks([])
ax[i, j].set_yticks([])
ax[i, j].set_xlabel('')
if j==0:
ax[i, j].set_ylabel(k1)
else:
ax[i, j].set_ylabel('')
if i==0:
ax[i, j].set_title(k2)
else:
ax[i, j].set_title('')
# create global legend
fig.legend(handles, labels_, loc='center',
bbox_to_anchor=(0.5, -0.02), ncol=len(labels_), fontsize=fontsize)
# adjust the figure
plt.tight_layout(rect=[0.1, 0.05, 1, 1])
plt.show()