cellina.CellinaGCN

class cellina.CellinaGCN(adata, n_hidden=128, n_latent=10, n_layers=2, discriminator_lambda=1.0, condition_on_intrinsic=False, link_prediction_weight=1.0, spatial_loss_type='supcon', classifier_lambda=1.0, supcon_temperature=0.25, num_neighbors=None, x_spatial_layer=None, use_observed_lib_size=True, convolution_type='gat', subgraph_type='induced', **model_kwargs)

Cellina model with dual encoders for counts (MLP) and spatial context (GCN).

Extends scVI with a GCN spatial encoder that learns spatial aggregation via message passing over the spatial connectivity graph. The two latent representations (z from counts, s from GCN) are concatenated (shifted = concat(z, s)) and decoded to reconstruct count data.
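
The concatenation step can be pictured with plain Python lists. This is only an illustrative sketch: concat_latents is a hypothetical helper, not part of cellina, and it assumes that shifted is simply z followed by s, which would give 2 * n_latent dimensions per cell.

```python
def concat_latents(z, s):
    """Concatenate per-cell latents row by row: shifted = concat(z, s)."""
    if len(z) != len(s):
        raise ValueError("z and s must describe the same number of cells")
    return [list(z_row) + list(s_row) for z_row, s_row in zip(z, s)]


n_latent = 2
z = [[0.1, 0.2], [0.3, 0.4]]  # stand-in for the counts (MLP) encoder output
s = [[1.0, 2.0], [3.0, 4.0]]  # stand-in for the spatial (GCN) encoder output
shifted = concat_latents(z, s)  # one row per cell, z columns first
```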

Parameters:
  • adata (AnnData) – AnnData registered via setup_anndata().

  • n_hidden (int (default: 128)) – Nodes per hidden layer (shared by both encoders).

  • n_latent (int (default: 10)) – Latent dimensionality for both z and s.

  • n_layers (int (default: 2)) – Hidden layers (shared by both encoders).

  • discriminator_lambda (float (default: 1.0)) – Weight for adversarial domain forgetting. 0 disables it.

  • condition_on_intrinsic (bool (default: False)) – If True, concatenate detached z to GCN input before message passing.

  • link_prediction_weight (float (default: 1.0)) – Weight for spatial loss on s. 0 disables it.

  • spatial_loss_type (str (default: 'supcon')) – "supcon" (supervised contrastive, default) or "domain_clf".

  • classifier_lambda (float (default: 1.0)) – Weight for cell-type classifier loss.

  • supcon_temperature (float (default: 0.25)) – SupCon temperature.

  • num_neighbors (List[int] (default: None)) – Neighbors sampled per GCN layer; its length is the number of sampled hops and should equal n_layers. Default: None -> [-1] * n_layers (all neighbors at every hop). Any length other than n_layers (including length 1) emits a UserWarning and is used as-is, sampling that many hops.

  • x_spatial_layer (Optional[str] (default: None)) – Optional adata.layers key for alternative spatial features.

  • use_observed_lib_size (bool (default: True)) – Must be True (graph batches require observed library size).

  • convolution_type (str (default: 'gat')) – GCN type: "gcn", "gat", "gin", "sg".

  • subgraph_type (str (default: 'induced')) – Subgraph construction mode: 'induced' materialises the full induced subgraph; 'directional' keeps only sampling-path edges. The counterfactual methods inherit this value when their own subgraph_type is None.

  • **model_kwargs – Keyword args for CellinaGCNModule.
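
The num_neighbors rule described above can be sketched as a small function. resolve_num_neighbors is a hypothetical name used for illustration; the actual cellina code and the wording of its warning may differ.

```python
import warnings


def resolve_num_neighbors(num_neighbors, n_layers):
    """Mirror the documented rule for the num_neighbors argument."""
    if num_neighbors is None:
        # -1 means "take all neighbors" at that hop.
        return [-1] * n_layers
    num_neighbors = list(num_neighbors)
    if len(num_neighbors) != n_layers:
        # The list is still used as-is, so it defines the number of sampled hops.
        warnings.warn(
            f"num_neighbors has length {len(num_neighbors)} but n_layers={n_layers}; "
            f"sampling {len(num_neighbors)} hop(s) as given.",
            UserWarning,
            stacklevel=2,
        )
    return num_neighbors
```

Passing a list whose length equals n_layers, for example [15, 10] with n_layers=2, avoids the warning.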

Examples

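>>> from cellina import CellinaGCN
>>> # domains_key is required because discriminator_lambda defaults to 1.0
>>> CellinaGCN.setup_anndata(adata, batch_key="batch", domains_key="domain",
...     spatial_connectivities_key="spatial_connectivities")
>>> model = CellinaGCN(adata, n_latent=10)
>>> model.train()
>>> adata.obsm["X_cellina_gcn"] = model.get_latent_representation()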

Methods table

get_counterfactual_expression(indices, ...)

Predict gene expression under a counterfactual spatial neighbourhood.

get_counterfactual_latents(indices, ...[, ...])

Return latent representations under a counterfactual spatial neighbourhood.

get_latent_representation([adata, indices, ...])

Return the latent representation for each cell.

get_marginal_ll([adata, indices, ...])

Get marginal log-likelihood of the data.

get_normalized_expression([adata, indices, ...])

Return normalized expression.

get_perturbed_expression([adata, indices, ...])

Predict gene expression using counterfactual node features for the GCN.

get_perturbed_latents([adata, indices, ...])

Return latent representations using counterfactual node features for the GCN.

setup_anndata(adata[, batch_key, ...])

Set up the AnnData object for this model.

train([max_epochs, accelerator, devices, ...])

Train the model.

Methods

get_counterfactual_expression

CellinaGCN.get_counterfactual_expression(indices, neighbour_indices, n_neighbors_per_seed=20, batch_size=None, seed=0, library_size='latent', return_numpy=True, subgraph_type=None)

Predict gene expression under a counterfactual spatial neighbourhood.

subgraph_type selects the counterfactual graph construction: None (default) inherits the model’s subgraph_type; 'directional' keeps only sampling-path edges (lower VRAM, output-equivalent for counterfactuals); 'induced' materialises the full induced subgraph (higher VRAM).

Return type:

ndarray
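
The inheritance rule for subgraph_type can be sketched as follows. resolve_subgraph_type is a hypothetical helper, and rejecting unknown values with a ValueError is an assumption rather than documented behaviour.

```python
VALID_SUBGRAPH_TYPES = ("induced", "directional")


def resolve_subgraph_type(requested, model_default="induced"):
    """Return the subgraph type a counterfactual call would use."""
    # None means "inherit whatever the model was constructed with".
    chosen = model_default if requested is None else requested
    if chosen not in VALID_SUBGRAPH_TYPES:
        # Assumption: invalid values are rejected up front.
        raise ValueError(
            f"subgraph_type must be one of {VALID_SUBGRAPH_TYPES}, got {chosen!r}"
        )
    return chosen
```

On a model built with the default subgraph_type='induced', passing subgraph_type='directional' to a counterfactual call is the documented way to lower VRAM use.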

get_counterfactual_latents

CellinaGCN.get_counterfactual_latents(indices, neighbour_indices, n_neighbors_per_seed=20, give_mean=False, batch_size=None, latent_key='s', seed=0, subgraph_type=None)

Return latent representations under a counterfactual spatial neighbourhood.

Parameters:
  • indices (ndarray) – Cell indices to compute counterfactual latents for.

  • neighbour_indices (ndarray) – Donor neighbourhood pool indices.

  • n_neighbors_per_seed (int (default: 20)) – Donors per seed. Raises ValueError if n_neighbors_per_seed >= len(neighbour_indices).

  • give_mean (bool (default: False)) – Return posterior mean rather than a sample.

  • batch_size (Optional[int] (default: None)) – Mini-batch size.

  • latent_key (str (default: 's')) – 'shifted', 'z', or 's'.

  • seed (int (default: 0)) – Random seed.

  • subgraph_type (Optional[str] (default: None)) – Counterfactual subgraph sampling mode. None (default) inherits the model’s subgraph_type; pass 'directional' to keep only sampling-path edges (lower VRAM, output-equivalent for counterfactuals) or 'induced' to materialise the full induced subgraph.

Return type:

ndarray
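
The interplay of neighbour_indices, n_neighbors_per_seed and seed can be sketched with the standard library. sample_donors is a hypothetical helper. Uniform sampling without replacement is an assumption, since the sampling scheme is not specified here; only the ValueError condition is taken from the parameter description.

```python
import random


def sample_donors(seed_indices, neighbour_indices, n_neighbors_per_seed=20, seed=0):
    """Pick a donor neighbourhood for every seed cell from a shared pool."""
    pool = list(neighbour_indices)
    if n_neighbors_per_seed >= len(pool):
        # Documented condition: the pool must be strictly larger than the draw.
        raise ValueError(
            "n_neighbors_per_seed must be smaller than len(neighbour_indices)"
        )
    rng = random.Random(seed)  # fixed seed gives reproducible draws
    # Assumption: each seed cell gets an independent draw without replacement.
    return {int(cell): rng.sample(pool, n_neighbors_per_seed) for cell in seed_indices}
```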

get_latent_representation

CellinaGCN.get_latent_representation(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='shifted')

Return the latent representation for each cell.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData; defaults to training data.

  • indices (Optional[list] (default: None)) – Cell indices.

  • give_mean (bool (default: False)) – Return posterior mean.

  • batch_size (Optional[int] (default: None)) – Mini-batch size.

  • latent_key (Optional[str] (default: 'shifted')) – 'shifted', 'z', or 's'.
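
A minimal sketch of pulling all three representations from a trained model. collect_latents is a hypothetical convenience wrapper; it relies only on the give_mean and latent_key arguments listed above.

```python
def collect_latents(model, keys=("shifted", "z", "s"), give_mean=True):
    """Return {latent_key: representation} for each requested key."""
    return {
        key: model.get_latent_representation(give_mean=give_mean, latent_key=key)
        for key in keys
    }
```

Each result could then be stored under its own adata.obsm key, for example "X_cellina_z" (a name chosen here for illustration), so that the counts-derived and spatial embeddings can be inspected separately.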

get_marginal_ll

CellinaGCN.get_marginal_ll(adata=None, indices=None, batch_size=None, n_mc_samples=1000, return_mean=True)

Get marginal log-likelihood of the data.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData to evaluate.

  • indices (Optional[list] (default: None)) – Cell indices.

  • batch_size (Optional[int] (default: None)) – Mini-batch size.

  • n_mc_samples (int (default: 1000)) – Monte Carlo importance-weighted samples per cell.

  • return_mean (bool (default: True)) – If True, return mean over all cells.
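
For orientation, the standard importance-weighted estimate averages n_mc_samples importance weights in log space. The sketch below shows that generic estimator for one cell; it is not taken from the cellina source, and importance_weighted_ll is a hypothetical name.

```python
import math


def importance_weighted_ll(log_weights):
    """Estimate log p(x) from log importance weights log p(x, z_i) - log q(z_i | x)."""
    n = len(log_weights)
    if n == 0:
        raise ValueError("need at least one Monte Carlo sample")
    # Log-sum-exp with the maximum factored out for numerical stability.
    m = max(log_weights)
    log_sum = m + math.log(sum(math.exp(w - m) for w in log_weights))
    # Subtracting log(n) turns the sum of weights into their mean.
    return log_sum - math.log(n)
```

With return_mean=True, per-cell estimates of this kind would be averaged into a single number.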

get_normalized_expression

CellinaGCN.get_normalized_expression(adata=None, indices=None, batch_size=None, return_numpy=True, library_size='latent')

Return normalized expression.

Parameters:
  • library_size (Union[float, str] (default: 'latent')) – "latent" (inferred), a float scalar, or 1 for pure proportions.
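
The effect of a numeric library_size can be sketched as a plain rescaling. This assumes, in the scVI style, that the decoder yields per-gene proportions summing to 1 for each cell; scale_expression is a hypothetical helper, and the "latent" option is not covered because it depends on the model's inferred library size.

```python
def scale_expression(proportions, library_size=1.0):
    """Scale per-cell gene proportions by a fixed library size."""
    return [[p * library_size for p in row] for row in proportions]


proportions = [[0.25, 0.75]]                      # one cell, two genes
as_proportions = scale_expression(proportions, 1)  # library_size=1 leaves proportions unchanged
per_10k = scale_expression(proportions, 1e4)       # expression per 10,000 counts
```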

get_perturbed_expression

CellinaGCN.get_perturbed_expression(adata=None, indices=None, batch_size=None, cf_layer='counts_cf', library_size='latent', return_numpy=True)

Predict gene expression using counterfactual node features for the GCN.

Return type:

Union[ndarray, torch.Tensor]

get_perturbed_latents

CellinaGCN.get_perturbed_latents(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='s', cf_layer='counts_cf')

Return latent representations using counterfactual node features for the GCN.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData; defaults to model’s adata.

  • indices (Optional[list] (default: None)) – Cell indices.

  • give_mean (bool (default: False)) – Return posterior mean.

  • batch_size (Optional[int] (default: None)) – Mini-batch size.

  • latent_key (str (default: 's')) – 'shifted', 'z', or 's'.

  • cf_layer (str (default: 'counts_cf')) – Key in adata.layers for counterfactual counts.

Return type:

ndarray
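
One way to use this method is to compare the spatial latent s with and without the counterfactual node features. The sketch below assumes that adata.layers already holds the counterfactual counts under cf_layer. spatial_shift is a hypothetical helper, and using give_mean=True so the difference is not blurred by sampling noise is a design choice made here, not a documented recommendation.

```python
def spatial_shift(model, cf_layer="counts_cf", indices=None):
    """Per-cell difference between perturbed and factual spatial latents (s)."""
    factual = model.get_latent_representation(
        indices=indices, give_mean=True, latent_key="s"
    )
    perturbed = model.get_perturbed_latents(
        indices=indices, give_mean=True, latent_key="s", cf_layer=cf_layer
    )
    # Row-wise subtraction; works for nested lists as well as array rows.
    return [
        [float(p) - float(f) for p, f in zip(p_row, f_row)]
        for p_row, f_row in zip(perturbed, factual)
    ]
```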

setup_anndata

classmethod CellinaGCN.setup_anndata(adata, batch_key=None, labels_key=None, domains_key=None, layer=None, categorical_covariate_keys=None, continuous_covariate_keys=None, spatial_connectivities_key='spatial_connectivities', **kwargs)

Set up the AnnData object for this model.

Parameters:
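  • adata (AnnData) – AnnData object to register.

  • batch_key (Optional[str] (default: None)) – Key in adata.obs for batch information.

  • labels_key (Optional[str] (default: None)) – Key in adata.obs for label information.

  • domains_key (Optional[str] (default: None)) – Key in adata.obs for domain labels. Required if discriminator_lambda > 0.

  • layer (Optional[str] (default: None)) – Key in adata.layers for the count data. If None, adata.X is used.

  • categorical_covariate_keys (Optional[List[str]] (default: None)) – Keys in adata.obs for categorical covariates.

  • continuous_covariate_keys (Optional[List[str]] (default: None)) – Keys in adata.obs for continuous covariates.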

  • spatial_connectivities_key (str (default: 'spatial_connectivities')) – Key in adata.obsp for the spatial connectivity matrix.
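
How the connectivity matrix is built is not specified here. One common choice is a symmetric k-nearest-neighbour graph over cell coordinates, sketched below as a dense 0/1 matrix using only the standard library. knn_connectivities is a hypothetical helper; in practice the matrix stored in adata.obsp is usually sparse and produced by a dedicated tool (squidpy's spatial_neighbors, for example, writes to obsp['spatial_connectivities']).

```python
import math


def knn_connectivities(coords, k):
    """Symmetric k-nearest-neighbour adjacency (dense 0/1 matrix) from coordinates."""
    n = len(coords)
    if not 0 < k < n:
        raise ValueError("k must satisfy 0 < k < number of cells")
    adj = [[0] * n for _ in range(n)]
    for i, point in enumerate(coords):
        # Rank all other cells by Euclidean distance to cell i.
        others = sorted(
            (j for j in range(n) if j != i),
            key=lambda j: math.dist(point, coords[j]),
        )
        for j in others[:k]:
            adj[i][j] = 1
            adj[j][i] = 1  # symmetrise so the graph is undirected
    return adj
```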

Return type:

Optional[AnnData]

train

CellinaGCN.train(max_epochs=400, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, batch_size=128, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)

Train the model.

Parameters:
  • max_epochs (int (default: 400)) – Number of passes through the dataset.

  • accelerator (str (default: 'auto')) – Accelerator type.

  • devices (int | list[int] | str (default: 'auto')) – Devices to use.

  • train_size (float (default: 0.9)) – Training set fraction.

  • validation_size (float | None (default: None)) – Validation set size.

  • shuffle_set_split (bool (default: True)) – Shuffle before splitting.

  • batch_size (int (default: 128)) – Minibatch size.

  • datasplitter_kwargs (dict | None (default: None)) – Extra kwargs for the data splitter.

  • plan_kwargs (dict | None (default: None)) – Keyword args for training plan.
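
How train_size and validation_size translate into cell counts is not spelled out above. The sketch below assumes the scVI-style convention in which validation_size=None means "everything not used for training" and any remainder becomes a held-out test set. split_sizes is a hypothetical helper and the rounding rules are assumptions.

```python
import math


def split_sizes(n_cells, train_size=0.9, validation_size=None):
    """Return (n_train, n_validation, n_test) under the assumed convention."""
    if not 0 < train_size <= 1:
        raise ValueError("train_size must be in (0, 1]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Assumption: the validation set takes all remaining cells.
        n_val = n_cells - n_train
    else:
        if train_size + validation_size > 1:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    n_test = n_cells - n_train - n_val
    return n_train, n_val, n_test
```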