cellina.Cellina#

class cellina.Cellina(adata, n_hidden=128, n_latent=10, n_layers=2, discriminator_lambda=1.0, condition_on_intrinsic=False, use_observed_lib_size=True, classifier_lambda=1.0, domain_classifier_lambda=0.0, **model_kwargs)#

Cellina model with dual encoders for counts and spatial data.

This model extends scVI with a spatial encoder that processes spatial features alongside the standard count encoder. The two latent representations (z from counts, s from spatial+z) are concatenated (shifted = concat(z, s)) and decoded together to reconstruct the count data.

Parameters:
  • adata (AnnData) – AnnData object that has been registered via setup_anndata().

  • n_hidden (int (default: 128)) – Number of nodes per hidden layer (shared by both encoders).

  • n_latent (int (default: 10)) – Dimensionality of the latent space for both z and s encoders.

  • n_layers (int (default: 2)) – Number of hidden layers (shared by both encoders).

  • discriminator_lambda (float (default: 1.0)) – Weight for adversarial domain forgetting. Set to 0 to disable.

  • **model_kwargs – Additional keyword arguments passed to CellinaModule.

Examples

>>> adata = anndata.read_h5ad(path_to_anndata)
>>> # adata.obsm["spatial_x"] should contain spatial features
>>> Cellina.setup_anndata(adata, batch_key="batch", spatial_obsm_key="spatial_x")
>>> model = Cellina(adata, n_latent=10)
>>> model.train()
>>> adata.obsm["X_cellina"] = model.get_latent_representation()  # Returns shifted = concat(z, s)
>>> adata.obsm["X_cellina_z"] = model.get_latent_representation(latent_key='z')
>>> adata.obsm["X_cellina_s"] = model.get_latent_representation(latent_key='s')

Methods table#

get_counterfactual_expression(indices, ...)

Predict gene expression under a counterfactual spatial neighbourhood.

get_counterfactual_latents(indices, ...[, ...])

Return latent representations under a counterfactual spatial neighbourhood.

get_latent_representation([adata, indices, ...])

Return the latent representation for each cell.

get_marginal_ll([adata, indices, ...])

Get marginal log-likelihood of the data.

get_normalized_expression([adata, indices, ...])

Return normalized expression, following the scvi-tools convention.

get_perturbed_expression([adata, indices, ...])

Return normalized expression using counterfactual spatial features.

get_perturbed_latents([adata, indices, ...])

Return latent representation using counterfactual spatial features.

setup_anndata(adata[, spatial_obsm_key, ...])

Set up the AnnData object for this model.

train([max_epochs, accelerator, devices, ...])

Train the model.

Methods#

get_counterfactual_expression#

Cellina.get_counterfactual_expression(indices, neighbour_indices, adata=None, batch_size=None, seed=0, library_size='latent', return_numpy=True, precomputed=True, n_neighbours=50, connectivity_key='spatial_connectivities')#

Predict gene expression under a counterfactual spatial neighbourhood.

Delegates to get_normalized_expression() after building a counterfactual AnnData with rewired spatial features.

Parameters:
  • indices (ndarray) – Cell indices to predict counterfactual expression for.

  • neighbour_indices (ndarray) – Indices of donor cells to sample spatial information from.

  • adata (Optional[AnnData] (default: None)) – Optional AnnData to use instead of self.adata.

  • batch_size (Optional[int] (default: None)) – Minibatch size for the loader.

  • seed (int (default: 0)) – Random seed for neighbour sampling.

  • library_size (Union[float, str] (default: 'latent')) – Passed to get_normalized_expression(). Defaults to "latent" (uses inferred library size, returning px_rate).

  • return_numpy (bool (default: True)) – Passed to get_normalized_expression().

Return type:

ndarray

Returns:

np.ndarray of shape (len(indices), n_genes).

get_counterfactual_latents#

Cellina.get_counterfactual_latents(indices, neighbour_indices, adata=None, give_mean=False, batch_size=None, latent_key='shifted', seed=0, precomputed=True, n_neighbours=50, connectivity_key='spatial_connectivities')#

Return latent representations under a counterfactual spatial neighbourhood.

Intrinsic z is computed from each cell’s own counts and is therefore unchanged. Spatial s is computed from the spatial neighbours of the donor cells in neighbour_indices rather than from each cell’s real spatial neighbours.

Parameters:
  • indices (ndarray) – Cell indices to compute counterfactual latents for.

  • neighbour_indices (ndarray) – Indices of donor cells to sample spatial information from.

  • adata (Optional[AnnData] (default: None)) – Optional AnnData to use instead of self.adata for generating the counterfactual loader.

  • give_mean (bool (default: False)) – If True, return the mean of the latent distribution.

  • batch_size (Optional[int] (default: None)) – Minibatch size for the loader.

  • latent_key (str (default: 'shifted')) – Which latent to return: 'shifted', 'z', or 's'.

  • seed (int (default: 0)) – Random seed for neighbour sampling.

Return type:

ndarray

Returns:

np.ndarray of shape (len(indices), n_latent), or (len(indices), 2*n_latent) when latent_key is 'shifted'.
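The rewiring idea described above can be pictured with a small NumPy sketch. It is illustrative only and is not Cellina's implementation: the helper name counterfactual_spatial_features, the dense boolean adjacency matrix, and the mean aggregation over a donor's neighbours are all assumptions made for this example.

```python
import numpy as np

def counterfactual_spatial_features(expr, adjacency, indices, neighbour_indices, seed=0):
    """Hypothetical helper: give each target cell the neighbourhood of a sampled donor."""
    rng = np.random.default_rng(seed)
    # One donor per target cell, drawn with a fixed seed so the result is reproducible.
    donors = rng.choice(neighbour_indices, size=len(indices), replace=True)
    features = np.zeros((len(indices), expr.shape[1]))
    for row, donor in enumerate(donors):
        nbrs = np.flatnonzero(adjacency[donor])  # the donor's spatial neighbours
        if nbrs.size:
            # Assumed aggregation: mean expression over the donor's neighbours.
            features[row] = expr[nbrs].mean(axis=0)
    return donors, features

# Toy data: 6 cells, 3 genes, and a ring-shaped neighbourhood graph.
expr = np.arange(18, dtype=float).reshape(6, 3)
adjacency = np.zeros((6, 6), dtype=bool)
for i in range(6):
    adjacency[i, (i - 1) % 6] = adjacency[i, (i + 1) % 6] = True

indices = np.array([0, 1])            # cells to rewire
neighbour_indices = np.array([3, 4])  # donor cells
donors, cf_features = counterfactual_spatial_features(
    expr, adjacency, indices, neighbour_indices, seed=0
)
```

Each row of cf_features is built from a donor's neighbourhood rather than from the target cell's own neighbours, and repeating the call with the same seed yields the same donor assignment.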

get_latent_representation#

Cellina.get_latent_representation(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='shifted')#

Return the latent representation for each cell.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object with equivalent structure to initial AnnData.

  • indices (Optional[list] (default: None)) – Indices of cells in adata to use.

  • give_mean (bool (default: False)) – If True, return the mean of the latent distribution. Otherwise, sample.

  • batch_size (Optional[int] (default: None)) – Minibatch size for data loading into model.

  • latent_key (Optional[str] (default: 'shifted')) – Which latent representation to return: 'shifted', 'z', or 's'. The default 'shifted' returns concat(z, s), which is what the decoder uses.

Returns:

Latent representation for each cell as a NumPy array:

  • If latent_key is 'shifted': concat(z, s), which is what goes into the decoder.

  • If latent_key is 'z': the z encoder output only.

  • If latent_key is 's': the s encoder output only.
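The three latent_key options differ only in which block of the latent space is returned. A minimal NumPy sketch of the resulting shapes, where select_latent is a hypothetical stand-in and random arrays replace real encoder outputs:

```python
import numpy as np

def select_latent(z, s, latent_key="shifted"):
    """Hypothetical helper mirroring the three latent_key options."""
    if latent_key == "z":
        return z
    if latent_key == "s":
        return s
    if latent_key == "shifted":
        # z comes first, then s, matching concat(z, s).
        return np.concatenate([z, s], axis=1)
    raise ValueError(f"Unknown latent_key: {latent_key!r}")

n_cells, n_latent = 5, 10
rng = np.random.default_rng(0)
z = rng.normal(size=(n_cells, n_latent))  # placeholder for the count-encoder output
s = rng.normal(size=(n_cells, n_latent))  # placeholder for the spatial-encoder output
shifted = select_latent(z, s, "shifted")
```

With n_latent=10, 'z' and 's' each have 10 columns while 'shifted' has 20, which matters when sizing downstream arrays such as adata.obsm entries.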

get_marginal_ll#

Cellina.get_marginal_ll(adata=None, indices=None, batch_size=None, n_mc_samples=1000, return_mean=True)#

Get marginal log-likelihood of the data. …

get_normalized_expression#

Cellina.get_normalized_expression(adata=None, indices=None, batch_size=None, return_numpy=True, library_size=1.0)#

Return normalized expression, following the scvi-tools convention.

Parameters:
  • library_size (Union[float, str] (default: 1.0)) – Scaling applied to the decoded expression. Accepted values:

      • float (e.g. 1e4): multiplies px_scale by this constant.

      • 1: returns px_scale (pure proportions).

      • "latent": uses the inferred latent library size (returns px_rate).
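The effect of library_size can be sketched on toy numbers. This is an illustration rather than the library's code: scale_expression is a hypothetical helper, and the latent_library values are made up to stand in for inferred per-cell library sizes.

```python
import numpy as np

def scale_expression(px_scale, library_size=1.0, latent_library=None):
    """Hypothetical helper: rescale per-cell gene proportions."""
    if isinstance(library_size, str):
        if library_size != "latent":
            raise ValueError(f"Unknown library_size: {library_size!r}")
        if latent_library is None:
            raise ValueError("latent_library is required when library_size='latent'")
        # Per-cell library size, broadcast across genes (the px_rate case).
        return px_scale * latent_library[:, None]
    # Constant scaling; 1.0 leaves the proportions untouched.
    return px_scale * float(library_size)

# Two cells, three genes; each row of proportions sums to 1.
px_scale = np.array([[0.5, 0.3, 0.2], [0.1, 0.1, 0.8]])
latent_library = np.array([1000.0, 2500.0])  # made-up inferred library sizes

proportions = scale_expression(px_scale)                        # rows sum to 1
per_10k = scale_expression(px_scale, 1e4)                       # rows sum to 10,000
px_rate = scale_expression(px_scale, "latent", latent_library)  # rows sum to each cell's library size
```

A constant such as 1e4 puts all cells on a common scale, whereas "latent" keeps per-cell depth differences in the output.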

get_perturbed_expression#

Cellina.get_perturbed_expression(adata=None, indices=None, batch_size=None, spatial_obsm_key='spatial_x_cf', library_size=1.0)#

Return normalized expression using counterfactual spatial features.

Temporarily swaps adata.obsm[registered_spatial_key] with adata.obsm[spatial_obsm_key], runs inference and decoding, then restores the original data.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object; defaults to the model’s registered adata.

  • indices (Optional[list] (default: None)) – Cell indices to use.

  • batch_size (Optional[int] (default: None)) – Mini-batch size for inference.

  • spatial_obsm_key (str (default: 'spatial_x_cf')) – Key in adata.obsm that holds the counterfactual spatial features.

  • library_size (Union[float, str] (default: 1.0)) – Passed directly to get_normalized_expression().

Return type:

ndarray
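The temporary swap-and-restore behaviour described for this method can be sketched with a context manager. This is an assumption about the general pattern, not the actual implementation: swapped_key is a hypothetical name, and a plain dict stands in for adata.obsm.

```python
from contextlib import contextmanager

@contextmanager
def swapped_key(store, registered_key, counterfactual_key):
    """Temporarily expose store[counterfactual_key] under registered_key."""
    original = store[registered_key]
    store[registered_key] = store[counterfactual_key]
    try:
        yield store
    finally:
        # Restore the original entry even if inference raises.
        store[registered_key] = original

# A plain dict stands in for adata.obsm in this sketch.
obsm = {"spatial_x": [[0.0, 1.0]], "spatial_x_cf": [[9.0, 9.0]]}
with swapped_key(obsm, "spatial_x", "spatial_x_cf") as view:
    # Inside the block, the registered key points at the counterfactual features.
    seen_during_inference = view["spatial_x"]
```

Placing the restore step in a finally clause means a failure during inference does not leave the counterfactual features registered in place of the originals.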

get_perturbed_latents#

Cellina.get_perturbed_latents(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='s', spatial_obsm_key='spatial_x_cf')#

Return latent representation using counterfactual spatial features.

Temporarily swaps adata.obsm[registered_spatial_key] with adata.obsm[spatial_obsm_key], runs inference, then restores the original data.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object; defaults to the model’s registered adata.

  • indices (Optional[list] (default: None)) – Cell indices to use.

  • give_mean (bool (default: False)) – Return the mean of the posterior rather than a sample.

  • batch_size (Optional[int] (default: None)) – Mini-batch size for inference.

  • latent_key (Optional[str] (default: 's')) – Which latent to return: 'shifted', 'z', or 's'. Default is 's' (the spatially-informed latent).

  • spatial_obsm_key (str (default: 'spatial_x_cf')) – Key in adata.obsm that holds the counterfactual spatial features (written by make_neighbor_perturbation()).

Return type:

ndarray

setup_anndata#

classmethod Cellina.setup_anndata(adata, spatial_obsm_key='spatial_x', batch_key=None, labels_key=None, domains_key=None, layer=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#

Set up the AnnData object for this model.

Parameters:
  • adata (AnnData) – AnnData object to register. Rows correspond to cells and columns to features.

  • spatial_obsm_key (str (default: 'spatial_x')) – Key in adata.obsm containing spatial features matrix.

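  • batch_key (Optional[str] (default: None)) – Key in adata.obs for batch information. If None, all cells are assigned to the same batch.

  • labels_key (Optional[str] (default: None)) – Key in adata.obs for label information. If None, all cells are assigned the same label.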

  • domains_key (Optional[str] (default: None)) – Key in adata.obs for domain labels (categorical). Required if discriminator_lambda > 0.

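  • layer (Optional[str] (default: None)) – Key in adata.layers holding the count data. If None, adata.X is used.

  • categorical_covariate_keys (Optional[list[str]] (default: None)) – Keys in adata.obs for categorical covariates.

  • continuous_covariate_keys (Optional[list[str]] (default: None)) – Keys in adata.obs for continuous covariates.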

Return type:

Optional[AnnData]

Returns:

: %(returns)s

train#

Cellina.train(max_epochs=400, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, batch_size=128, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)#

Train the model.

Parameters:
  • max_epochs (int (default: 400)) – Number of passes through the dataset.

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “ipu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

  • devices (int | list[int] | str (default: 'auto')) – The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.

  • train_size (float (default: 0.9)) – Size of training set in the range [0.0, 1.0].

  • validation_size (float | None (default: None)) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.

  • shuffle_set_split (bool (default: True)) – Whether to shuffle indices before splitting. If False, the validation, training, and test sets are taken in the sequential order of the data, sized according to validation_size and train_size.

  • batch_size (int (default: 128)) – Minibatch size to use during training.

  • datasplitter_kwargs (dict | None (default: None)) – Additional keyword arguments passed into DataSplitter.

  • plan_kwargs (dict | None (default: None)) – Keyword args for CellinaAdversarialTrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • **kwargs – Other keyword args for Trainer.
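How train_size and validation_size partition the cells can be sketched as follows. The helper split_sizes is hypothetical, and the rounding rule (ceiling for the training set, floor for the validation set) is an assumption; the actual data splitter may round differently.

```python
import math

def split_sizes(n_cells, train_size=0.9, validation_size=None):
    """Hypothetical helper: return (n_train, n_val, n_test) cell counts."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be in [0.0, 1.0]")
    n_train = math.ceil(train_size * n_cells)  # rounding rule is an assumption
    if validation_size is None:
        # Validation takes everything not used for training, so no test set.
        n_val = n_cells - n_train
    else:
        if train_size + validation_size > 1.0:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    # Whatever remains after training and validation forms the test set.
    n_test = n_cells - n_train - n_val
    return n_train, n_val, n_test

default_split = split_sizes(1000)
with_test_set = split_sizes(1000, train_size=0.8, validation_size=0.1)
```

Under these assumptions, 1000 cells with the defaults give 900 training and 100 validation cells with no test set, while train_size=0.8 and validation_size=0.1 leave 100 cells for a test set.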