cellina.CellinaGCN#
- class cellina.CellinaGCN(adata, n_hidden=128, n_latent=10, n_layers=2, discriminator_lambda=1.0, condition_on_intrinsic=False, link_prediction_weight=1.0, spatial_loss_type='supcon', classifier_lambda=1.0, supcon_temperature=0.25, num_neighbors=None, x_spatial_layer=None, use_observed_lib_size=True, convolution_type='gat', subgraph_type='induced', **model_kwargs)#
Cellina model with dual encoders for counts (MLP) and spatial context (GCN).
Extends scVI with a GCN spatial encoder that learns spatial aggregation via message passing over the spatial connectivity graph. The two latent representations (z from counts, s from GCN) are concatenated (shifted = concat(z, s)) and decoded to reconstruct count data.
- Parameters:
adata (AnnData) – AnnData registered via setup_anndata().
n_hidden (int (default: 128)) – Nodes per hidden layer (shared by both encoders).
n_latent (int (default: 10)) – Latent dimensionality for both z and s.
n_layers (int (default: 2)) – Hidden layers (shared by both encoders).
discriminator_lambda (float (default: 1.0)) – Weight for adversarial domain forgetting. 0 disables it.
condition_on_intrinsic (bool (default: False)) – If True, concatenate detached z to GCN input before message passing.
link_prediction_weight (float (default: 1.0)) – Weight for spatial loss on s. 0 disables it.
spatial_loss_type (str (default: 'supcon')) – "supcon" (supervised contrastive, default) or "domain_clf".
classifier_lambda (float (default: 1.0)) – Weight for cell-type classifier loss.
supcon_temperature (float (default: 0.25)) – SupCon temperature.
num_neighbors (List[int] (default: None)) – Neighbors sampled per GCN layer; its length is the number of sampled hops and should equal n_layers. Default: None -> [-1] * n_layers (all neighbors at every hop). Any length other than n_layers (including length 1) emits a UserWarning and is used as-is, sampling that many hops.
x_spatial_layer (Optional[str] (default: None)) – Optional adata.layers key for alternative spatial features.
use_observed_lib_size (bool (default: True)) – Must be True (graph batches require observed library size).
convolution_type (str (default: 'gat')) – GCN type: "gcn", "gat", "gin", "sg".
**model_kwargs – Keyword args for CellinaGCNModule.
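The num_neighbors rules described above can be sketched as a small standalone function. This is only an illustration of the documented behaviour, not the library's code: resolve_num_neighbors is a hypothetical name, and the warning text is made up.

```python
import warnings


def resolve_num_neighbors(num_neighbors, n_layers):
    """Hypothetical helper mirroring the documented num_neighbors rules."""
    if num_neighbors is None:
        # Default: -1 means "all neighbors", repeated once per hop.
        return [-1] * n_layers
    if len(num_neighbors) != n_layers:
        # Documented behaviour: warn, then use the list unchanged.
        warnings.warn(
            f"num_neighbors has length {len(num_neighbors)} but "
            f"n_layers={n_layers}; sampling {len(num_neighbors)} hop(s).",
            UserWarning,
        )
    return list(num_neighbors)
```

With n_layers=2, passing None resolves to [-1, -1], while passing [10] is kept as a single sampled hop and triggers the warning.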
Examples
>>> CellinaGCN.setup_anndata(adata, batch_key="batch",
...     spatial_connectivities_key="spatial_connectivities")
>>> model = CellinaGCN(adata, n_latent=10)
>>> model.train()
>>> adata.obsm["X_cellina_gcn"] = model.get_latent_representation()
Methods table#
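| get_counterfactual_expression | Predict gene expression under a counterfactual spatial neighbourhood. |
| get_counterfactual_latents | Return latent representations under a counterfactual spatial neighbourhood. |
| get_latent_representation | Return the latent representation for each cell. |
| get_marginal_ll | Get marginal log-likelihood of the data. |
| get_normalized_expression | Return normalized expression. |
| get_perturbed_expression | Predict gene expression using counterfactual node features for the GCN. |
| get_perturbed_latents | Return latent representations using counterfactual node features for the GCN. |
| setup_anndata | Register an AnnData object for use with the model. |
| train | Train the model. |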
Methods#
get_counterfactual_expression#
- CellinaGCN.get_counterfactual_expression(indices, neighbour_indices, n_neighbors_per_seed=20, batch_size=None, seed=0, library_size='latent', return_numpy=True, subgraph_type=None)#
Predict gene expression under a counterfactual spatial neighbourhood.
subgraph_type selects the counterfactual graph construction: None (default) inherits the model’s subgraph_type; 'directional' keeps only sampling-path edges (lower VRAM, output-equivalent for counterfactuals); 'induced' materialises the full induced subgraph (higher VRAM).
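The difference between the two modes can be illustrated on a toy graph. The sketch below is one interpretation of "sampling-path edges" versus "induced subgraph" and is not taken from the library: all function names are hypothetical, sampling is one hop and deterministic for clarity, and edges are written as (source, target) pairs.

```python
def sample_one_hop(adjacency, seeds, n_per_seed):
    """Take the first n_per_seed neighbours of each seed (deterministic for clarity)."""
    return {seed: adjacency[seed][:n_per_seed] for seed in seeds}


def directional_edges(sampled):
    # Only the neighbour -> seed edges that were followed while sampling.
    return {(nbr, seed) for seed, nbrs in sampled.items() for nbr in nbrs}


def induced_edges(adjacency, sampled):
    # Every edge of the original graph whose endpoints were both sampled.
    nodes = set(sampled) | {n for nbrs in sampled.values() for n in nbrs}
    return {(u, v) for u in nodes for v in adjacency[u] if v in nodes}


# Toy symmetric graph: node -> list of neighbours.
adjacency = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
sampled = sample_one_hop(adjacency, seeds=[0], n_per_seed=2)
directional = directional_edges(sampled)
induced = induced_edges(adjacency, sampled)
```

Here the directional set holds only the two edges pointing into seed 0, while the induced set also contains the edges among the sampled neighbours, which is why it needs more memory.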
get_counterfactual_latents#
- CellinaGCN.get_counterfactual_latents(indices, neighbour_indices, n_neighbors_per_seed=20, give_mean=False, batch_size=None, latent_key='s', seed=0, subgraph_type=None)#
Return latent representations under a counterfactual spatial neighbourhood.
- Parameters:
indices (ndarray) – Cell indices to compute counterfactual latents for.
neighbour_indices (ndarray) – Donor neighbourhood pool indices.
n_neighbors_per_seed (int (default: 20)) – Donors per seed. Raises ValueError if >= len(neighbour_indices).
give_mean (bool (default: False)) – Return posterior mean rather than a sample.
batch_size (Optional[int] (default: None)) – Mini-batch size.
latent_key (str (default: 's')) – 'shifted', 'z', or 's'.
seed (int (default: 0)) – Random seed.
subgraph_type (Optional[str] (default: None)) – Counterfactual subgraph sampling mode. None (default) inherits the model’s subgraph_type; pass 'directional' to keep only sampling-path edges (lower VRAM, output-equivalent for counterfactuals) or 'induced' to materialise the full induced subgraph.
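The n_neighbors_per_seed check and the role of seed can be sketched as follows. This is a minimal illustration under stated assumptions: sample_donors is a hypothetical helper, and drawing without replacement from the donor pool is an assumption, since the parameter descriptions only specify the ValueError condition.

```python
import random


def sample_donors(neighbour_indices, n_neighbors_per_seed=20, seed=0):
    """Hypothetical sketch: draw donors for one seed cell without replacement."""
    if n_neighbors_per_seed >= len(neighbour_indices):
        # Mirrors the documented condition for raising ValueError.
        raise ValueError(
            "n_neighbors_per_seed must be smaller than len(neighbour_indices)"
        )
    rng = random.Random(seed)  # fixed seed gives a reproducible draw
    return rng.sample(list(neighbour_indices), n_neighbors_per_seed)
```

A pool of 100 donors with the default of 20 per seed is valid, while a pool of exactly 20 is rejected because the check uses >=.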
get_latent_representation#
- CellinaGCN.get_latent_representation(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='shifted')#
Return the latent representation for each cell.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData; defaults to training data.
indices (Optional[list] (default: None)) – Cell indices.
give_mean (bool (default: False)) – Return posterior mean.
batch_size (Optional[int] (default: None)) – Mini-batch size.
latent_key (Optional[str] (default: 'shifted')) – 'shifted', 'z', or 's'.
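How the three latent_key values relate can be shown with plain lists. The sketch assumes, per the class description, that 'shifted' is the concatenation of z and s, so it would have 2 * n_latent dimensions; select_latent is a hypothetical helper and the values are placeholders.

```python
def select_latent(z, s, latent_key="shifted"):
    """Hypothetical helper: pick one per-cell latent vector by key."""
    if latent_key == "z":
        return list(z)
    if latent_key == "s":
        return list(s)
    if latent_key == "shifted":
        # Documented as shifted = concat(z, s).
        return list(z) + list(s)
    raise ValueError(f"unknown latent_key: {latent_key!r}")


n_latent = 10
z = [0.1] * n_latent  # counts-encoder latent (placeholder values)
s = [0.2] * n_latent  # GCN-encoder latent (placeholder values)
shifted = select_latent(z, s, "shifted")
```

Under this assumption, downstream code that expects n_latent columns should request 'z' or 's' explicitly rather than rely on the default.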
get_marginal_ll#
- CellinaGCN.get_marginal_ll(adata=None, indices=None, batch_size=None, n_mc_samples=1000, return_mean=True)#
Get marginal log-likelihood of the data.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData to evaluate.
indices (Optional[list] (default: None)) – Cell indices.
batch_size (Optional[int] (default: None)) – Mini-batch size.
n_mc_samples (int (default: 1000)) – Monte Carlo importance-weighted samples per cell.
return_mean (bool (default: True)) – If True, return mean over all cells.
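For intuition, the standard importance-weighted estimate of a marginal log-likelihood is the log of the mean importance weight, computed stably in log space. The sketch below shows that generic estimator; it is not the library's implementation, and both function names are hypothetical. Each inner list stands for one cell's n_mc_samples log-weights.

```python
import math


def log_mean_exp(log_weights):
    """Numerically stable log of the mean of exp(log_weights)."""
    m = max(log_weights)
    total = sum(math.exp(w - m) for w in log_weights)
    return m + math.log(total / len(log_weights))


def marginal_ll(per_cell_log_weights, return_mean=True):
    # One estimate per cell from its importance log-weights,
    # log w = log p(x, z) - log q(z | x).
    per_cell = [log_mean_exp(lw) for lw in per_cell_log_weights]
    return sum(per_cell) / len(per_cell) if return_mean else per_cell
```

With return_mean=False the sketch yields one value per cell, which corresponds to the alternative to averaging over all cells.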
get_normalized_expression#
- CellinaGCN.get_normalized_expression(adata=None, indices=None, batch_size=None, return_numpy=True, library_size='latent')#
Return normalized expression.
- Parameters:
library_size (Union[float, str] (default: 'latent')) – "latent" (inferred), a float scalar, or 1 for pure proportions.
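The three library_size options can be read as different scaling factors applied to decoded per-gene proportions. The following sketch assumes that reading; scale_expression and its latent_library argument are hypothetical and stand in for whatever the model does internally.

```python
def scale_expression(proportions, library_size, latent_library=None):
    """Hypothetical sketch of how library_size could scale decoded proportions."""
    if library_size == "latent":
        if latent_library is None:
            raise ValueError("latent_library is required when library_size='latent'")
        factor = latent_library  # per-cell library size inferred by the model
    else:
        factor = float(library_size)  # fixed scalar; 1 leaves proportions unchanged
    return [p * factor for p in proportions]
```

For example, a scalar of 1e4 would express values on a counts-per-10,000 scale, while 1 returns the proportions themselves.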
get_perturbed_expression#
Predict gene expression using counterfactual node features for the GCN.
get_perturbed_latents#
- CellinaGCN.get_perturbed_latents(adata=None, indices=None, give_mean=False, batch_size=None, latent_key='s', cf_layer='counts_cf')#
Return latent representations using counterfactual node features for the GCN.
- Parameters:
adata (Optional[AnnData] (default: None)) – AnnData; defaults to model’s adata.
indices (Optional[list] (default: None)) – Cell indices.
give_mean (bool (default: False)) – Return posterior mean.
batch_size (Optional[int] (default: None)) – Mini-batch size.
latent_key (str (default: 's')) – 'shifted', 'z', or 's'.
cf_layer (str (default: 'counts_cf')) – Key in adata.layers for counterfactual counts.
setup_anndata#
- classmethod CellinaGCN.setup_anndata(adata, batch_key=None, labels_key=None, domains_key=None, layer=None, categorical_covariate_keys=None, continuous_covariate_keys=None, spatial_connectivities_key='spatial_connectivities', **kwargs)#
Register an AnnData object for use with the model.
- Parameters:
adata – AnnData object to register.
batch_key (default: None)
labels_key (default: None)
domains_key (Optional[str] (default: None)) – Key in adata.obs for domain labels. Required if discriminator_lambda > 0.
layer (default: None)
categorical_covariate_keys (default: None)
continuous_covariate_keys (default: None)
spatial_connectivities_key (str (default: 'spatial_connectivities')) – Key in adata.obsp for the spatial connectivity matrix.
- Return type:
Optional[AnnData]
train#
- CellinaGCN.train(max_epochs=400, accelerator='auto', devices='auto', train_size=0.9, validation_size=None, shuffle_set_split=True, batch_size=128, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)#
Train the model.
- Parameters:
max_epochs (int (default: 400)) – Number of passes through the dataset.
accelerator (str (default: 'auto')) – Accelerator type.
devices (int | list[int] | str (default: 'auto')) – Devices to use.
train_size (float (default: 0.9)) – Training set fraction.
validation_size (float | None (default: None)) – Validation set size.
shuffle_set_split (bool (default: True)) – Shuffle before splitting.
batch_size (int (default: 128)) – Minibatch size.
datasplitter_kwargs (dict | None (default: None)) – Extra kwargs for the data splitter.
plan_kwargs (dict | None (default: None)) – Keyword args for training plan.
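How train_size, validation_size and shuffle_set_split interact can be sketched on plain indices. This is an illustration, not the library's data splitter: split_indices is a hypothetical helper, and treating validation_size=None as "the remainder after training" is an assumption borrowed from the usual scvi-tools convention. Graph-aware batching is ignored here.

```python
import random


def split_indices(n_cells, train_size=0.9, validation_size=None,
                  shuffle_set_split=True, seed=0):
    """Hypothetical sketch of a train/validation split over cell indices."""
    if validation_size is None:
        # Assumption: the validation set takes whatever training leaves over.
        validation_size = 1.0 - train_size
    if train_size + validation_size > 1.0 + 1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")
    indices = list(range(n_cells))
    if shuffle_set_split:
        random.Random(seed).shuffle(indices)
    n_train = round(train_size * n_cells)
    n_val = min(round(validation_size * n_cells), n_cells - n_train)
    return indices[:n_train], indices[n_train:n_train + n_val]
```

With the defaults and 1000 cells, this yields 900 training and 100 validation indices, with no cell in both sets.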