cellina.Cellina
- class cellina.Cellina(adata, n_hidden=128, n_latent=10, n_layers=2, discriminator_lambda=1.0, condition_on_intrinsic=False, use_observed_lib_size=True, classifier_lambda=1.0, domain_classifier_lambda=0.0, **model_kwargs)
Cellina model with dual encoders for counts and spatial data.
This model extends scVI with a spatial encoder that processes spatial features alongside the standard count encoder. The two latent representations, z (from the counts) and s (from the spatial features together with z), are concatenated as shifted = concat(z, s), and this joint representation is decoded to reconstruct the count data.
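The data flow described above can be sketched at the level of array shapes. This is a toy illustration only, assuming random linear maps stand in for the two encoders and the decoder and a softmax decoder analogous to scVI's px_scale; the helpers `encode_z`, `encode_s`, and `decode` are hypothetical and are not part of the `cellina` API.

```python
import numpy as np

rng = np.random.default_rng(0)

n_cells, n_genes, n_spatial = 5, 30, 8
n_latent = 10  # z and s share this dimensionality

counts = rng.poisson(2.0, size=(n_cells, n_genes)).astype(float)
spatial_x = rng.normal(size=(n_cells, n_spatial))  # stand-in for adata.obsm["spatial_x"]

# Hypothetical random linear maps standing in for the two encoders and the decoder.
W_z = rng.normal(size=(n_genes, n_latent))
W_s = rng.normal(size=(n_spatial + n_latent, n_latent))
W_dec = rng.normal(size=(2 * n_latent, n_genes))


def encode_z(x):
    """Intrinsic latent z from log-transformed counts only."""
    return np.log1p(x) @ W_z


def encode_s(spatial, z):
    """Spatial latent s from the spatial features together with z."""
    return np.concatenate([spatial, z], axis=1) @ W_s


def decode(shifted):
    """Row-wise softmax over genes, giving proportions (analogous to px_scale)."""
    logits = shifted @ W_dec
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)


z = encode_z(counts)
s = encode_s(spatial_x, z)
shifted = np.concatenate([z, s], axis=1)  # shifted = concat(z, s)
px_scale = decode(shifted)
print(z.shape, s.shape, shifted.shape, px_scale.shape)  # (5, 10) (5, 10) (5, 20) (5, 30)
```

The key point is that the decoder sees 2 * n_latent inputs, which is why the default `latent_key='shifted'` output is twice as wide as z or s alone.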
- Parameters:
adata (AnnData) – AnnData object that has been registered via setup_anndata().
n_hidden (int (default: 128)) – Number of nodes per hidden layer (shared by both encoders).
n_latent (int (default: 10)) – Dimensionality of the latent space for each of the z and s encoders.
n_layers (int (default: 2)) – Number of hidden layers (shared by both encoders).
discriminator_lambda (float (default: 1.0)) – Weight for adversarial domain forgetting. Set to 0 to disable.
**model_kwargs – Keyword args for CellinaModule.
Examples
>>> import anndata
>>> from cellina import Cellina
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> # adata.obsm["spatial_x"] should contain spatial features
>>> Cellina.setup_anndata(adata, batch_key="batch", spatial_obsm_key="spatial_x")
>>> model = Cellina(adata, n_latent=10)
>>> model.train()
>>> adata.obsm["X_cellina"] = model.get_latent_representation()  # shifted = concat(z, s)
>>> adata.obsm["X_cellina_z"] = model.get_latent_representation(latent_key="z")
>>> adata.obsm["X_cellina_s"] = model.get_latent_representation(latent_key="s")
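Methods table
| Method | Description |
|---|---|
| get_counterfactual_expression | Predict gene expression under a counterfactual spatial neighbourhood. |
| get_counterfactual_latents | Return latent representations under a counterfactual spatial neighbourhood. |
| get_latent_representation | Return the latent representation for each cell. |
| get_marginal_ll | Get marginal log-likelihood of the data. |
| get_normalized_expression | Return normalized expression, following the scvi-tools convention. |
| get_perturbed_expression | Return normalised expression using counterfactual spatial features. |
| get_perturbed_latents | Return latent representation using counterfactual spatial features. |
| setup_anndata | Set up the AnnData object for this model. |
| train | Train the model. |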
Methods
get_counterfactual_expression
- Cellina.get_counterfactual_expression(indices, neighbour_indices, adata=None, batch_size=None, seed=0, library_size='latent', return_numpy=True, precomputed=True, n_neighbours=50, connectivity_key='spatial_connectivities')
Predict gene expression under a counterfactual spatial neighbourhood.
Delegates to get_normalized_expression() after building a counterfactual AnnData with rewired spatial features.
- Parameters:
indices (ndarray) – Cell indices to predict counterfactual expression for.
neighbour_indices (ndarray) – Indices of donor cells to sample spatial information from.
adata (Optional[AnnData] (default: None)) – Optional AnnData to use instead of self.adata.
batch_size (Optional[int] (default: None)) – Minibatch size for the loader.
seed (int (default: 0)) – Random seed for neighbour sampling.
library_size (Union[float, str] (default: 'latent')) – Passed to get_normalized_expression(). Defaults to "latent" (uses the inferred library size, returning px_rate).
return_numpy (bool (default: True)) – Passed to get_normalized_expression().
- Return type:
ndarray
- Returns:
np.ndarray of shape (len(indices), n_genes).
get_counterfactual_latents
- Cellina.get_counterfactual_latents(indices, neighbour_indices, adata=None, give_mean=False, batch_size=None, latent_key='shifted', seed=0, precomputed=True, n_neighbours=50, connectivity_key='spatial_connectivities')
Return latent representations under a counterfactual spatial neighbourhood.
Intrinsic z is computed from each cell's own counts (unchanged). Spatial s is computed from the spatial neighbours of the cells in neighbour_indices instead of each cell's real spatial neighbours.
- Parameters:
indices (ndarray) – Cell indices to compute counterfactual latents for.
neighbour_indices (ndarray) – Indices of donor cells to sample spatial information from.
adata (Optional[AnnData] (default: None)) – Optional AnnData to use instead of self.adata for generating the counterfactual loader.
give_mean (bool (default: False)) – If True, return the mean of the latent distribution. Otherwise, sample.
batch_size (Optional[int] (default: None)) – Minibatch size for the loader.
latent_key (str (default: 'shifted')) – Which latent to return: 'shifted', 'z', or 's'.
seed (int (default: 0)) – Random seed for neighbour sampling.
- Return type:
ndarray
- Returns:
np.ndarray of shape (len(indices), n_latent), or (len(indices), 2 * n_latent) for latent_key='shifted'.
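One way to picture the neighbour rewiring is the sketch below. It assumes, purely for illustration, that each target cell in `indices` is paired one-to-one with a donor in `neighbour_indices`, that the connectivity matrix named by `connectivity_key` is available as a dense adjacency array, and that spatial features are the mean of up to `n_neighbours` sampled donor neighbours. The helper `counterfactual_spatial_features` is hypothetical; the aggregation Cellina actually uses may differ.

```python
import numpy as np


def counterfactual_spatial_features(expr, connectivities, indices, neighbour_indices,
                                    n_neighbours=50, seed=0):
    """Hypothetical sketch: spatial features for each cell in `indices`, built
    from the neighbours of its paired donor in `neighbour_indices` rather than
    from its own neighbours. The cells' own counts (and hence z) are untouched.

    expr:           (n_cells, n_features) per-cell features to aggregate.
    connectivities: (n_cells, n_cells) dense adjacency; nonzero = neighbours.
    """
    indices = np.asarray(indices)
    neighbour_indices = np.asarray(neighbour_indices)
    if indices.shape != neighbour_indices.shape:
        raise ValueError("this sketch pairs indices and neighbour_indices one-to-one")
    rng = np.random.default_rng(seed)
    out = np.zeros((len(indices), expr.shape[1]))
    for row, donor in enumerate(neighbour_indices):
        nbrs = np.flatnonzero(connectivities[donor])  # the donor's neighbours
        if nbrs.size == 0:
            continue  # isolated donor: leave an all-zero row
        k = min(n_neighbours, nbrs.size)
        sampled = rng.choice(nbrs, size=k, replace=False)
        out[row] = expr[sampled].mean(axis=0)  # mean aggregation is an assumption
    return out


expr = np.arange(12, dtype=float).reshape(4, 3)
conn = np.array([[0, 1, 0, 0],
                 [1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0]])
# Give cell 0 the neighbourhood of cell 2, whose neighbours are cells 1 and 3.
feats = counterfactual_spatial_features(expr, conn, indices=[0], neighbour_indices=[2])
print(feats)  # [[6. 7. 8.]]
```

Fixing `seed` makes the neighbour subsample reproducible, which matters once a donor has more than `n_neighbours` neighbours.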
get_latent_representation
- Cellina.get_latent_representation(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='shifted')
Return the latent representation for each cell.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData object with the same structure as the initial AnnData.
indices (Optional[list] (default: None)) – Indices of cells in adata to use.
give_mean (bool (default: False)) – If True, return the mean of the latent distribution. Otherwise, sample.
batch_size (Optional[int] (default: None)) – Minibatch size for data loading into the model.
latent_key (Optional[str] (default: 'shifted')) – Which latent representation to return: 'shifted', 'z', or 's'. Default: 'shifted' (returns concat(z, s), which is what the decoder uses).
- Returns:
Latent representation for each cell as a NumPy array:
- latent_key='shifted': concat(z, s), the input to the decoder
- latent_key='z': the z encoder output only
- latent_key='s': the s encoder output only
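Because 'shifted' is documented as concat(z, s), the 'z' and 's' outputs correspond to the first and last n_latent columns of the shifted array. The hypothetical helper below makes that column layout explicit, assuming z comes first as the notation concat(z, s) indicates. Slicing a single 'shifted' result should only be expected to match separate calls with latent_key='z' or 's' when give_mean=True, since otherwise each call presumably draws its own sample.

```python
import numpy as np


def select_latent(shifted, n_latent, latent_key="shifted"):
    """Hypothetical helper: return the block that `latent_key` selects,
    assuming shifted = concat(z, s) with z in the first n_latent columns."""
    if shifted.shape[1] != 2 * n_latent:
        raise ValueError("expected shifted to have 2 * n_latent columns")
    if latent_key == "shifted":
        return shifted
    if latent_key == "z":
        return shifted[:, :n_latent]
    if latent_key == "s":
        return shifted[:, n_latent:]
    raise ValueError(f"unknown latent_key: {latent_key!r}")


n_latent = 10
shifted = np.random.default_rng(0).normal(size=(100, 2 * n_latent))
z = select_latent(shifted, n_latent, "z")
s = select_latent(shifted, n_latent, "s")
print(z.shape, s.shape)  # (100, 10) (100, 10)
```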
get_marginal_ll
- Cellina.get_marginal_ll(adata=None, indices=None, batch_size=None, n_mc_samples=1000, return_mean=True)
Get marginal log-likelihood of the data. …
get_normalized_expression
- Cellina.get_normalized_expression(adata=None, indices=None, batch_size=None, return_numpy=True, library_size=1.0)
Return normalized expression, following the scvi-tools convention.
- Parameters:
library_size (Union[float, str] (default: 1.0)) – How the decoded expression is scaled:
float (e.g. 1e4): multiplies px_scale by this constant.
1 (default): returns px_scale itself (pure proportions).
"latent": uses the inferred latent library size (returns px_rate).
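The three library_size modes can be summarized as a small function. This is a sketch, not Cellina's implementation: the helper `apply_library_size` is hypothetical, and the "latent" branch assumes the scVI convention px_rate = exp(library) * px_scale, with `log_library` holding per-cell log library sizes.

```python
import numpy as np


def apply_library_size(px_scale, library_size=1.0, log_library=None):
    """Hypothetical sketch of the three library_size modes described above.

    px_scale:    (n_cells, n_genes) proportions; each row sums to 1.
    log_library: (n_cells, 1) log library sizes, used only for "latent".
    """
    if isinstance(library_size, str):
        if library_size != "latent":
            raise ValueError(f"unsupported library_size: {library_size!r}")
        if log_library is None:
            raise ValueError('library_size="latent" requires log_library')
        # scVI-style convention (assumed here): px_rate = exp(library) * px_scale
        return np.exp(log_library) * px_scale
    # Numeric library_size: 1 leaves px_scale unchanged; 1e4 gives counts per 10k.
    return float(library_size) * px_scale


px_scale = np.array([[0.5, 0.25, 0.25], [0.1, 0.1, 0.8]])
proportions = apply_library_size(px_scale)            # same as px_scale
per_10k = apply_library_size(px_scale, 1e4)           # px_scale scaled by 1e4
px_rate = apply_library_size(px_scale, "latent", np.log([[200.0], [50.0]]))
```

With "latent", each row sums to that cell's library size rather than to a shared constant, which is why get_counterfactual_expression() defaults to it when expression on the original count scale is wanted.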
get_perturbed_expression
- Cellina.get_perturbed_expression(adata=None, indices=None, batch_size=None, spatial_obsm_key='spatial_x_cf', library_size=1.0)
Return normalised expression using counterfactual spatial features.
Temporarily swaps adata.obsm[registered_spatial_key] with adata.obsm[spatial_obsm_key], runs inference and decoding, then restores the original data.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData object; defaults to the model's registered adata.
indices (Optional[list] (default: None)) – Cell indices to use.
batch_size (Optional[int] (default: None)) – Minibatch size for inference.
spatial_obsm_key (str (default: 'spatial_x_cf')) – Key in adata.obsm that holds the counterfactual spatial features.
library_size (Union[float, str] (default: 1.0)) – Passed directly to get_normalized_expression().
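The swap-then-restore step described above is a natural fit for a context manager. The sketch below is hypothetical: a plain dict stands in for adata.obsm, and `swapped_obsm` is not a Cellina function. The documentation only states that the original data are restored; whether Cellina guards the restore with try/finally, as done here, is not stated.

```python
from contextlib import contextmanager

import numpy as np


@contextmanager
def swapped_obsm(obsm, registered_key, counterfactual_key):
    """Hypothetical sketch: expose counterfactual features under the registered
    key for the duration of the block, then restore the original array."""
    original = obsm[registered_key]
    obsm[registered_key] = obsm[counterfactual_key]
    try:
        yield obsm
    finally:
        # Runs even if inference raises, so the original features are never lost.
        obsm[registered_key] = original


# A plain dict stands in for adata.obsm here.
obsm = {"spatial_x": np.zeros((3, 2)), "spatial_x_cf": np.ones((3, 2))}
with swapped_obsm(obsm, "spatial_x", "spatial_x_cf"):
    seen_during = obsm["spatial_x"].copy()  # inference and decoding would run here
seen_after = obsm["spatial_x"]
```

The same pattern applies to get_perturbed_latents(), which performs the swap but stops after inference.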
get_perturbed_latents
- Cellina.get_perturbed_latents(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='s', spatial_obsm_key='spatial_x_cf')
Return latent representation using counterfactual spatial features.
Temporarily swaps adata.obsm[registered_spatial_key] with adata.obsm[spatial_obsm_key], runs inference, then restores the original data.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData object; defaults to the model's registered adata.
indices (Optional[list] (default: None)) – Cell indices to use.
give_mean (bool (default: False)) – Return the mean of the posterior rather than a sample.
batch_size (Optional[int] (default: None)) – Minibatch size for inference.
latent_key (Optional[str] (default: 's')) – Which latent to return: 'shifted', 'z', or 's'. Default is 's' (the spatially informed latent).
spatial_obsm_key (str (default: 'spatial_x_cf')) – Key in adata.obsm that holds the counterfactual spatial features (written by make_neighbor_perturbation()).
setup_anndata
- classmethod Cellina.setup_anndata(adata, spatial_obsm_key='spatial_x', batch_key=None, labels_key=None, domains_key=None, layer=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
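Set up the AnnData object for this model.
- Parameters:
adata (AnnData) – AnnData object. Rows represent cells, columns represent features.
spatial_obsm_key (str (default: 'spatial_x')) – Key in adata.obsm containing the spatial feature matrix.
batch_key – Key in adata.obs for batch information.
labels_key – Key in adata.obs for label information.
domains_key (Optional[str] (default: None)) – Key in adata.obs for domain labels (categorical). Required if discriminator_lambda > 0.
layer – If not None, uses this as the key in adata.layers for raw count data.
categorical_covariate_keys – Keys in adata.obs that correspond to categorical data.
continuous_covariate_keys – Keys in adata.obs that correspond to continuous data.
- Return type:
Optional[AnnData]
- Returns:
%(returns)s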
train
- Cellina.train(max_epochs=400, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, batch_size=128, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)
Train the model.
- Parameters:
max_epochs (int (default: 400)) – Number of passes through the dataset.
accelerator (str (default: 'auto')) – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.
devices (int | list[int] | str (default: 'auto')) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.
train_size (float (default: 0.9)) – Size of the training set in the range [0.0, 1.0].
validation_size (float | None (default: None)) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
shuffle_set_split (bool (default: True)) – Whether to shuffle indices before splitting. If False, the validation, training, and test sets are split in the sequential order of the data according to the validation_size and train_size percentages.
batch_size (int (default: 128)) – Minibatch size to use during training.
datasplitter_kwargs (dict | None (default: None)) – Additional keyword arguments passed into DataSplitter.
plan_kwargs (dict | None (default: None)) – Keyword args for CellinaAdversarialTrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
**kwargs – Other keyword args for Trainer.
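The interaction between train_size, validation_size, and shuffle_set_split can be illustrated with the split arithmetic below. This is a hypothetical sketch, not the DataSplitter implementation: `split_indices` is an illustrative helper, and its rounding (ceil for training cells, floor for an explicit validation fraction) is an assumption. It follows the validation, train, test ordering described for shuffle_set_split=False.

```python
import math

import numpy as np


def split_indices(n_obs, train_size=0.9, validation_size=None,
                  shuffle_set_split=True, seed=0):
    """Hypothetical sketch of the split described by train()'s parameters.

    Returns (validation, train, test) index arrays, in the val/train/test order
    used when shuffle_set_split=False. Rounding choices are assumptions.
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be in [0.0, 1.0]")
    n_train = math.ceil(train_size * n_obs)
    if validation_size is None:
        n_val = n_obs - n_train  # validation gets everything not in train
    else:
        if train_size + validation_size > 1.0 + 1e-9:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_obs)
    order = np.arange(n_obs)
    if shuffle_set_split:
        order = np.random.default_rng(seed).permutation(n_obs)
    val = order[:n_val]
    train = order[n_val:n_val + n_train]
    test = order[n_val + n_train:]  # non-empty only if the two sizes sum to < 1
    return val, train, test


val, train, test = split_indices(1000)  # 100 / 900 / 0
val2, train2, test2 = split_indices(100, train_size=0.8, validation_size=0.1)  # 10 / 80 / 10
```

Computing the default validation size as the remainder n_obs - n_train, rather than as (1 - train_size) * n_obs, avoids floating-point rounding silently moving a cell into the test set.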