import abc
import time
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ncem.utils.losses import GaussianLoss, KLLoss, NegBinLoss
from ncem.utils.metrics import (custom_kl, custom_mae, custom_mean_sd,
custom_mse, custom_mse_scaled,
gaussian_reconstruction_loss, logp1_custom_mse,
logp1_r_squared, logp1_r_squared_linreg,
nb_reconstruction_loss, r_squared,
r_squared_linreg)
def transfer_layers(model1, model2):
"""
Transfer layer weights from model 1 to model 2.
:param model1: Input model.
:param model2: Output model.
"""
layer_names_model1 = [x.name for x in model1.layers]
layer_names_model2 = [x.name for x in model2.layers]
layers_updated = []
layer_not_updated = set(layer_names_model2)
for x in layer_names_model1:
w = model1.get_layer(name=x).get_weights()
if x in layer_names_model2:
# Only update layers with parameters:
if len(w) > 0:
model2.get_layer(x).set_weights(w)
layers_updated.append(x)
layer_not_updated = layer_not_updated.difference({x})
print(f"updated layers: {layers_updated}")
print(f"did not update layers: {layer_not_updated}")
[docs]class Estimator:
"""Estimator class for models.
Contains all necessary methods for data loading, model initialization, training, evaluation and prediction.
"""
img_to_patient_dict: Dict[str, str]
complete_img_keys: List[str]
a: dict # dict of adjacency matrices of shape (max_nodes, max_nodes)
h_0: Dict[str, np.ndarray] # dict of adjacency matrices of shape (max_nodes, n_features_0)
h_1: Dict[str, np.ndarray] # dict of adjacency matrices of shape (max_nodes, n_features_1)
size_factors: Dict[str, np.ndarray]
graph_covar: Dict[str, np.ndarray]
node_covar: Dict[str, np.ndarray]
domains: Dict[str, np.ndarray]
covar_selection: Union[List[str], Tuple[str], None]
node_types: Dict[str, np.ndarray]
node_type_names: Dict[str, str]
graph_covar_names: Dict[str, List[str]]
node_feature_names: List[str]
n_features_type: int
n_features_standard: int
n_features_0: int
n_features_1: int
n_graph_covariates: int
n_node_covariates: int
n_domains: int
max_nodes: int
n_eval_nodes_per_graph: int
vi_model: bool
log_transform: bool
model_type: str
adj_type: str
cond_type: str
cond_depth: int
output_layer: str
img_keys_test = list
img_keys_eval = list
img_keys_train = list
nodes_idx_test = Dict[str, list]
nodes_idx_eval = Dict[str, list]
nodes_idx_train = Dict[str, list]
steps_per_epoch: int
validation_steps: int
def __init__(self):
"""Initialize Estimator class."""
self.model = None
self.loss = []
self.metrics = []
self.optimizer = None
self.beta = None
self.max_beta = None
self.pre_warm_up = None
self.train_hyperparam = {}
self.history = {}
self.pretrain_history = {}
self.train_dataset = None
self.eval_dataset = None
def _load_data(
self,
data_origin: str,
data_path: str,
radius: Optional[int] = None,
n_rings: int = 1,
label_selection: Optional[List[str]] = None,
n_top_genes: Optional[int] = None
):
"""Initialize a DataLoader object.
Parameters
----------
data_origin : str
Data origin.
data_path : str
Data path.
radius : int
Radius.
label_selection : list, optional
Label selection.
n_top_genes: int, optional
N top genes for highly variable gene selection.
Raises
------
ValueError
If `data_origin` not recognized.
"""
coord_type = 'generic'
self.targeted_assay = True
if data_origin.startswith("zhang"):
from ncem.data import DataLoaderZhang as DataLoader
self.undefined_node_types = ["other"]
elif data_origin.startswith("jarosch"):
from ncem.data import DataLoaderJarosch as DataLoader
self.undefined_node_types = None
elif data_origin.startswith("hartmann"):
from ncem.data import DataLoaderHartmann as DataLoader
self.undefined_node_types = None
elif data_origin == "pascualreguant":
from ncem.data import DataLoaderPascualReguant as DataLoader
self.undefined_node_types = ["other"]
elif data_origin.startswith("schuerch"):
from ncem.data import DataLoaderSchuerch as DataLoader
self.undefined_node_types = [
"dirt",
"undefined",
"tumor cells / immune cells",
"immune cells / vasculature",
]
elif data_origin.startswith('lohoff'):
from ncem.data import DataLoaderLohoff as DataLoader
self.undefined_node_types = ['Low quality']
elif data_origin.startswith("luwt"):
if data_origin == "luwt_imputation":
from ncem.data import DataLoaderLuWTimputed as DataLoader
else:
from ncem.data import DataLoaderLuWT as DataLoader
self.undefined_node_types = ['Unknown']
elif data_origin.startswith("lutet2"):
from ncem.data import DataLoaderLuTET2 as DataLoader
self.undefined_node_types = ['Unknown']
elif data_origin == "10xvisium":
from ncem.data import DataLoader10xVisiumMouseBrain as DataLoader
self.undefined_node_types = None
if n_rings > 1:
coord_type = 'grid'
else:
n_rings = 1
coord_type = 'generic'
radius = 0
elif data_origin == "10xvisium_lymphnode":
from ncem.data import DataLoader10xLymphnode as DataLoader
self.undefined_node_types = None
if n_rings > 1:
coord_type = 'grid'
else:
n_rings = 1
coord_type = 'generic'
radius = 0
elif data_origin.startswith('destvi_lymphnode'):
self.targeted_assay = False
from ncem.data import DataLoaderDestViLymphnode as DataLoader
self.undefined_node_types = None
elif data_origin.startswith('destvi_mousebrain'):
self.targeted_assay = False
from ncem.data import DataLoaderDestViMousebrain as DataLoader
self.undefined_node_types = None
elif data_origin.startswith('cell2location_lymphnode'):
self.targeted_assay = False
from ncem.data import DataLoaderCell2locationLymphnode as DataLoader
self.undefined_node_types = None
elif data_origin == "salasiss":
from ncem.data import DataLoaderSalasIss as DataLoader
self.undefined_node_types = None
self.data = DataLoader(
data_path, radius=radius, coord_type=coord_type, n_rings=n_rings, label_selection=label_selection,
n_top_genes=n_top_genes
)
[docs] def get_data(
self,
data_origin: str,
data_path: str,
radius: Optional[int],
n_rings: int = 1,
graph_covar_selection: Optional[Union[List[str], Tuple[str]]] = None,
node_label_space_id: str = "type",
node_feature_space_id: str = "standard",
use_covar_node_position: bool = False,
use_covar_node_label: bool = False,
use_covar_graph_covar: bool = False,
domain_type: str = "image",
robustness: Optional[float] = None,
robustness_seed: int = 1,
n_top_genes: Optional[int] = None,
segmentation_robustness: Optional[List[float]] = None,
resimulate_nodes: bool = False,
resimulate_nodes_w_depdency: bool = False,
resimulate_nodes_sparsity_rate: float = 0.5,
):
"""Get data used in estimator classes.
Parameters
----------
data_origin : str
Data origin.
data_path : str
Data path.
radius : int, optional
Radius.
n_rings : int
Number of rings of neighbors for grid data.
graph_covar_selection : list, tuple, optional
Selected graph covariates.
node_label_space_id : str
Node label space id.
node_feature_space_id : str
Node feature space id.
use_covar_node_position : bool
Whether to use node position as covariate.
use_covar_node_label : bool
Whether to use node label as covariate.
use_covar_graph_covar : bool
Whether to use graph covariates.
domain_type : str
Covariate that is used as domain.
robustness : float, optional
Optional fraction of images for robustness test.
robustness_seed: int
Seed for robustness analysis
n_top_genes: int, optional
N top genes for highly variable gene selection.
segmentation_robustness: list, optional
Parameters for segmentation robustness fit, float for fraction of nodes and float for signal overflow.
Raises
------
ValueError
If sub-selected covar_selection could not be found, `node_label_space_id` or `node_feature_space_id`
not recognized
"""
if self.adj_type is None:
raise ValueError("set adj_type by init_estim() first")
if graph_covar_selection is None:
graph_covar_selection = []
labels_to_load = graph_covar_selection
self._load_data(
data_origin=data_origin,
data_path=data_path,
radius=radius,
n_rings=n_rings,
label_selection=labels_to_load,
n_top_genes=n_top_genes
)
if robustness:
np.random.seed(robustness_seed)
n_images = np.int(len(self.data.img_celldata) * robustness)
print(n_images)
image_keys = list(np.random.choice(
a=list(self.data.img_celldata.keys()),
size=n_images,
replace=False,
))
self.data.img_celldata = {k: self.data.img_celldata[k] for k in image_keys}
metadata = self.data.celldata.uns["metadata"]
self.data.celldata = self.data.celldata[self.data.celldata.obs[metadata['image_col']].isin(image_keys)]
print(
"\nAttention: Running robustness model with a fraction %f images, so [%i] images. \n"
"\nThis also adjusts celldata and img_celldata."
% (
robustness,
n_images,
)
)
if segmentation_robustness:
node_fraction = segmentation_robustness[0]
overflow_fraction = segmentation_robustness[1]
total_size = np.int(self.data.celldata.shape[0] * node_fraction)
for key, ad in self.data.img_celldata.items():
size = np.int(ad.shape[0] * node_fraction)
random_indices = np.random.choice(ad.shape[0], size=size, replace=False)
a = ad.obsp['adjacency_matrix_connectivities'].toarray()
err_ad = ad.copy()
for idx in random_indices:
adj = a[idx, :]
neigh_idx = np.random.choice(np.where(adj == 1.)[0], size=1, replace=False)
err_ad.X[idx, :] = ad.X[idx, :] + overflow_fraction * ad.X[neigh_idx, :]
err_ad.X[neigh_idx, :] = (1. - overflow_fraction) * ad.X[neigh_idx, :]
self.data.img_celldata[key] = err_ad
print(
"\nAttention: Running segmentation robustness model on %f of all nodes, so [%i] nodes. \n"
"\nSignal overflow is set to %f . This adjusts img_celldata, celldata remains unchanged.\n"
% (
node_fraction,
total_size,
overflow_fraction
)
)
self.simulation = False
if resimulate_nodes:
self.simulation = True
n_target_cell_types = 2 if resimulate_nodes_w_depdency else 1
dependencies_per_type = 1
# Create map from real cell types to simulated ones (can be coarser):
found_all_types = False
futile_counter = 0
node_type_map_idx = None
while not found_all_types:
node_type_map_idx = np.array([
np.random.randint(low=0, high=n_target_cell_types)
for _ in self.data.celldata.uns["node_type_names"].keys()
])
futile_counter += 1
if np.all([x in node_type_map_idx for x in range(n_target_cell_types)]):
found_all_types = True
if futile_counter > 100:
raise ValueError("did not manage to sample all target cell types")
node_type_names = dict([
(x, "sim_" + str(y))
for x, y in zip(self.data.celldata.uns["node_type_names"].keys(), node_type_map_idx)
])
self.data.celldata.uns["node_type_names"] = node_type_names
nfeatures = self.data.img_celldata[list(self.data.img_celldata.keys())[0]].n_vars
# Mean effect per simulated cell types:
effect_ct = np.random.uniform(low=0., high=10., size=(n_target_cell_types, nfeatures))
# Create dependency structure of cell types.
# Base line dependency structure with all dependencies as 0.
cov_ct = np.zeros((n_target_cell_types, n_target_cell_types))
if resimulate_nodes_w_depdency:
# Add dependencies_per_type for each cell type:
for i in range(n_target_cell_types):
# Sample desired dependencies from non-self cell types:
js = np.random.choice(a=[ii for ii in range(n_target_cell_types) if i != ii],
size=dependencies_per_type, replace=False)
cov_ct[i, js] = 1.
# Pairwise dependencies: Effect (self cell type, neighbor cell type, feature)
effect_neighbors = np.random.uniform(low=4., high=6.,
size=(n_target_cell_types, n_target_cell_types, nfeatures))
# Simulate sparse effects:
sparsity_rate = resimulate_nodes_sparsity_rate # fraction of zero effects
effect_neighbors[np.random.binomial(n=1, p=sparsity_rate, size=effect_neighbors.shape) == 1.] = 0.
else:
effect_neighbors = np.zeros((n_target_cell_types, n_target_cell_types, nfeatures))
sigma_sq = 1.
self._simulation_parameters = {
"effect_ct": effect_ct,
"cov_ct": cov_ct,
"effect_neighbors": effect_neighbors,
"sigma_sq": sigma_sq,
"adatas": {}
}
for key, ad in self.data.img_celldata.items():
adj = ad.obsp['adjacency_matrix_connectivities'].toarray()
sim_ad = ad.copy()
nobs = sim_ad.n_obs
# Assign all cells from old cell type sets to corresponding new cell types, assumes one hot encoding.
sim_ad.obsm["node_types"] = np.concatenate([
np.expand_dims(
np.max(sim_ad.obsm["node_types"][:, np.where(node_type_map_idx == i)[0]], axis=1),
axis=1
)
for i in range(n_target_cell_types)
], axis=1)
assert np.all(sim_ad.obsm["node_types"].sum(axis=1) == 1.)
# Simulate count matrix:
dmat_ct = sim_ad.obsm["node_types"]
loc_neighbors = np.zeros((nobs, nfeatures))
for i in range(nobs):
ct = np.where(dmat_ct[i, :] == 1.)[0]
ct = ct[0] # flatten list of length 1
dmat_neighhors_i = np.zeros((1, n_target_cell_types,))
if resimulate_nodes_w_depdency:
for j in np.where(np.asarray(adj[i, :]).flatten() > 0)[0]:
ct_j = np.where(dmat_ct[j, :] == 1.)[0]
ct_j = ct_j[0] # flatten list of length 1
dmat_neighhors_i[0, ct_j] = 1.
loc_neighbors[i, :] = np.matmul(dmat_neighhors_i, effect_neighbors[ct])[0]
loc = np.matmul(dmat_ct, effect_ct) + loc_neighbors
sim_ad.X = np.random.normal(loc=loc, scale=sigma_sq)
self.data.img_celldata[key] = sim_ad
# Record simulation:
self._simulation_parameters["adatas"][key] = {
"adj": adj,
"ct": sim_ad.obsm["node_types"],
"x": sim_ad.X,
}
print(
"\nAttention: Running simulation-based expression augmentation. \n"
"\nThis adjusts img_celldata, celldata remains unchanged.\n"
)
# Validate graph-wise covariate selection:
if len(graph_covar_selection) > 0:
if (
np.sum(
[
x not in self.data.celldata.uns["graph_covariates"]["label_selection"]
for x in graph_covar_selection
]
)
> 0
):
raise ValueError(
"could not find some sub-selected covar_selection %s in %s"
% (str(graph_covar_selection), str(self.data.celldata.uns["graph_covariates"]["label_selection"]))
)
self.img_to_patient_dict = self.data.celldata.uns["img_to_patient_dict"]
self.complete_img_keys = list(self.data.img_celldata.keys())
self.a = {k: adata.obsp["adjacency_matrix_connectivities"] for k, adata in self.data.img_celldata.items()}
if self.adj_type == "scaled":
self.a = self.data._transform_all_a(self.a)
if node_label_space_id == "standard":
self.h_0 = {k: adata.X for k, adata in self.data.img_celldata.items()}
elif node_label_space_id == "type":
self.h_0 = {k: adata.obsm["node_types"] for k, adata in self.data.img_celldata.items()}
elif node_label_space_id == 'proportions':
self.h_0 = {k: adata.obsm["proportions"] for k, adata in self.data.img_celldata.items()}
else:
raise ValueError("node_label_space_id %s not recognized" % node_label_space_id)
if node_feature_space_id == "standard":
self.h_1 = {k: adata.X for k, adata in self.data.img_celldata.items()}
elif node_feature_space_id == "type":
self.h_1 = {k: adata.obsm["node_types"] for k, adata in self.data.img_celldata.items()}
else:
raise ValueError("node_feature_space_id %s not recognized" % node_feature_space_id)
self.node_types = {k: adata.obsm["node_types"] for k, adata in self.data.img_celldata.items()}
self.node_type_names = self.data.celldata.uns["node_type_names"]
self.n_features_type = list(self.node_types.values())[0].shape[1]
self.n_features_standard = self.data.celldata.shape[1]
self.node_feature_names = list(self.data.celldata.var_names)
self.size_factors = self.data.size_factors()
# Add covariates:
# Add graph-level covariate information
self.covar_selection = graph_covar_selection
self.graph_covar_names = self.data.celldata.uns["graph_covariates"]["label_names"]
# Split loaded graph-wise covariates into labels (output, Y) and covariates / features (input, C)
if len(graph_covar_selection) > 0:
self.graph_covar = { # Single 1D array per observation: concatenate all covariates!
k: np.concatenate(
[adata.uns["graph_covariates"]["label_tensors"][kk] for kk in self.covar_selection], axis=0
)
for k, adata in self.data.img_celldata.items()
}
# Replace masked entries (np.nan) by zeros: (masking can be handled properly in output but not here):
for k, v in self.graph_covar.items():
if np.any(np.isnan(v)):
self.graph_covar[k][np.isnan(v)] = 0.0
else:
# Create empty covariate arrays:
self.graph_covar = {k: np.array([], ndmin=1) for k, adata in self.data.img_celldata.items()}
# Add node-level conditional information
self.node_covar = {k: np.empty((adata.shape[0], 0)) for k, adata in self.data.img_celldata.items()}
# Cell position in image:
if use_covar_node_position:
for k in self.complete_img_keys:
self.node_covar[k] = np.append(self.node_covar[k], self.data.img_celldata[k].obsm["spatial"], axis=1)
print("Position_matrix added to categorical predictor matrix")
# Add graph-level covariates to node covariates:
if use_covar_graph_covar:
for k in self.complete_img_keys:
# Broadcast graph-level covariate to nodes:
c = np.repeat(self.graph_covar[k][np.newaxis, :], self.data.img_celldata[k].shape[0], axis=0)
self.node_covar[k] = np.append(self.node_covar[k], c, axis=1)
print("Node_covar_selection broadcasted to categorical predictor matrix")
# Add node
if use_covar_node_label:
for k in self.complete_img_keys:
node_types = self.data.img_celldata[k].obsm["node_types"]
self.node_covar[k] = np.append(self.node_covar[k], node_types, axis=1)
print("Node_type added to categorical predictor matrix")
# Set selection-specific tensor dimensions:
self.n_features_0 = list(self.h_0.values())[0].shape[1]
self.n_features_1 = list(self.h_1.values())[0].shape[1]
self.n_graph_covariates = list(self.graph_covar.values())[0].shape[0]
self.n_node_covariates = list(self.node_covar.values())[0].shape[1]
self.max_nodes = max([self.a[i].shape[0] for i in self.complete_img_keys])
# Define domains
if domain_type == "image":
self.domains = {key: i for i, key in enumerate(self.complete_img_keys)}
elif domain_type == "patient":
self.domains = {
key: list(self.patient_ids_unique).index(self.img_to_patient_dict[key])
for i, key in enumerate(self.complete_img_keys)
}
else:
assert False
self.n_domains = len(np.unique(list(self.domains.values())))
if self.targeted_assay:
self.proportions = None
else:
self.proportions = {k: adata.obsm["proportions"] for k, adata in self.data.img_celldata.items()}
# Report summary statistics of loaded graph:
print(
"Mean of mean node degree per images across images: %f"
% np.mean([np.mean(v.sum(axis=1)) for k, v in self.a.items()])
)
@abc.abstractmethod
def _get_dataset(
self,
image_keys: List[str],
nodes_idx: Dict[str, np.ndarray],
batch_size: int,
shuffle_buffer_size: Optional[int],
train: bool,
seed: Optional[int],
prefetch: int = 100,
reinit_n_eval: Optional[int] = None,
):
"""Prepare a dataset.
Parameters
----------
image_keys : np.array
Image keys in partition.
nodes_idx : dict, str
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
shuffle_buffer_size : int, optional
Shuffle buffer size.
train : bool
Whether dataset is used for training or not (influences shuffling of nodes).
seed : int, optional
Random seed.
prefetch: int
Prefetch of dataset.
reinit_n_eval : int, optional
Used if model is reinitialized to different number of nodes per graph.
"""
pass
@abc.abstractmethod
def _get_resampled_dataset(
self,
image_keys: np.ndarray,
nodes_idx: dict,
batch_size: int,
seed: Optional[int] = None,
prefetch: int = 100,
):
"""Evaluate model based on resampled dataset for posterior resampling.
node_1 + domain_1 -> encoder -> z_1 + domain_2 -> decoder -> reconstruction_2.
Parameters
----------
image_keys: np.array
Image keys in partition.
nodes_idx : dict
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
seed : int, optional
Seed.
prefetch : int
Prefetch.
"""
pass
[docs] @abc.abstractmethod
def init_model(self, **kwargs):
"""Initialize and compiles the model.
Parameters
----------
kwargs
Arbitrary keyword arguments.
"""
pass
@property
def patient_ids_bytarget(self) -> np.ndarray:
"""Return patient identifiers by target.
Returns
-------
patient_ids_bytarget
"""
return np.array([self.img_to_patient_dict[x] for x in self.complete_img_keys])
@property
def patient_ids_unique(self) -> np.ndarray:
"""Return unique patient identifiers.
Returns
-------
patient_ids_unique
"""
return np.unique(self.patient_ids_bytarget)
@property
def img_keys_all(self):
"""Return all image keys.
Returns
-------
img_keys_all
"""
return np.unique(np.concatenate([self.img_keys_train, self.img_keys_eval, self.img_keys_test])).tolist()
@property
def nodes_idx_all(self):
"""Return all node indices.
Returns
-------
nodes_idx_all
"""
return dict(
[
(
x,
np.unique(
np.concatenate(
[
self.nodes_idx_train[x] if x in self.nodes_idx_train.keys() else np.array([]),
self.nodes_idx_eval[x] if x in self.nodes_idx_eval.keys() else np.array([]),
self.nodes_idx_test[x] if x in self.nodes_idx_test.keys() else np.array([]),
]
)
),
)
for x in list(self.img_keys_all)
]
)
@staticmethod
def _prepare_sf(x):
"""Prepare size factors.
Parameters
----------
x
Inout array.
Returns
-------
size_factors
Raises
------
ValueError
x.shape > 2
"""
if len(x.shape) == 2:
sf = np.asarray(x.sum(axis=1)).flatten()
elif len(x.shape) == 1:
sf = np.asarray(x.sum()).flatten()
else:
raise ValueError("x.shape > 2")
sf = sf / np.mean(sf)
return sf
def _compile_model(self, optimizer: tf.keras.optimizers.Optimizer, output_layer: str):
"""Compile all necessary models.
ATTENTION: Decoder compiled with same optimizer instance as training model if an instance is passed!
Parameters
----------
optimizer : tf.keras.optimizers.Optimizer
Optimizer to be used for training model (and decoder).
output_layer : str
Output layer to be used (e.g. gaussian).
Raises
------
ValueError
If `output_layer` is not recognized.
"""
self.vi_model = False # variational inference
if self.model_type in ["cvae", "cvae_ncem"]:
self.vi_model = True
enc_dec_model = self.model_type == "cvae" or self.model_type == "cvae_ncem"
if output_layer in ["gaussian", "gaussian_const_disp", "linear", "linear_const_disp"]:
reconstruction_loss = GaussianLoss()
reconstruction_metrics = [
custom_mae,
custom_mean_sd,
custom_mse,
custom_mse_scaled,
gaussian_reconstruction_loss,
r_squared,
r_squared_linreg,
]
elif output_layer == "nb" or output_layer == "nb_shared_disp" or output_layer == "nb_const_disp":
reconstruction_loss = NegBinLoss()
reconstruction_metrics = [
custom_mae,
custom_mean_sd,
nb_reconstruction_loss,
logp1_custom_mse,
logp1_r_squared,
logp1_r_squared_linreg,
]
else:
raise ValueError("output_layer %s not recognized" % output_layer)
self.output_layer = output_layer
if self.vi_model:
self.loss = [
reconstruction_loss,
KLLoss(beta=self.beta, max_beta=self.max_beta, pre_warm_up=self.pre_warm_up),
]
self.metrics = [reconstruction_metrics, [custom_kl]]
else:
self.loss = [reconstruction_loss]
self.metrics = reconstruction_metrics
self.model.training_model.compile(optimizer=optimizer, loss=self.loss, metrics=self.metrics)
# Also compile sampling model / decoder if available:
if enc_dec_model:
self.model.decoder.compile(optimizer=optimizer, loss=self.loss, metrics=self.metrics)
if self.vi_model:
self.model.decoder_sampling.compile(optimizer=optimizer, loss=self.loss, metrics=self.metrics)
def _remove_unidentified_nodes(self, node_idx) -> Tuple[int, dict]:
"""Exclude undefined cells from data set.
Parameters
----------
node_idx
Data set to remove unidentified nodes from.
Returns
-------
tuple
number of unidentifed nodes removed, data set with unidentified nodes removed from.
"""
if self.undefined_node_types is not None:
# Identify cells with undefined cell_type in all target images
undefined_label_idx = np.where([x in self.undefined_node_types for x in self.node_type_names.values()])[0]
# Extract shape and number for print statement
n_undefined_nodes = np.sum(
[
np.sum(np.any(self.node_types[k][v, :][:, undefined_label_idx] == 1, axis=1))
for k, v in node_idx.items()
]
)
node_idx.update(
{
k: v[
np.where(np.logical_not(np.any(self.node_types[k][v, :][:, undefined_label_idx] == 1, axis=1)))[
0
]
]
for k, v in node_idx.items()
}
)
else:
n_undefined_nodes = 0
return n_undefined_nodes, node_idx
[docs] def split_data_given(
self, img_keys_test, img_keys_train, img_keys_eval, nodes_idx_test, nodes_idx_train, nodes_idx_eval
):
"""Split data by given partition.
Parameters
----------
img_keys_test
Test image keys.
img_keys_train
Train image keys.
img_keys_eval
Evaluation image keys.
nodes_idx_test
Test node indices.
nodes_idx_train
Train node indices.
nodes_idx_eval
Evaluation node indices.
"""
self.img_keys_test = img_keys_test
self.img_keys_train = img_keys_train
self.img_keys_eval = img_keys_eval
self.nodes_idx_test = nodes_idx_test
self.nodes_idx_train = nodes_idx_train
self.nodes_idx_eval = nodes_idx_eval
[docs] def split_data_node(self, test_split: float, validation_split: float, seed: int = 1):
"""Split nodes randomly into partitions.
Parameters
----------
test_split : float
Fraction of total nodes to be in test set.
validation_split : float
Fraction of train-eval nodes to be in validation split.
seed : int
Seed for random selection of observations.
Raises
------
ValueError
If evaluation or test dataset are empty.
"""
print(
"Using split method: node. \n Train-test-validation split is based on total number of nodes "
"per patients over all images."
)
np.random.seed(seed)
h_nodes_dict = {a: b.shape[0] for a, b in self.h_0.items()}
all_nodes = sum(h_nodes_dict.values())
nodes_all_idx = {a: np.arange(0, b) for a, b in h_nodes_dict.items()}
self.img_keys_test = list(self.complete_img_keys)
self.img_keys_train = list(self.complete_img_keys)
self.img_keys_eval = list(self.complete_img_keys)
n_undefined_nodes, nodes_all_idx = self._remove_unidentified_nodes(node_idx=nodes_all_idx)
# updating h_nodes_dict to only include the number of identified cells
h_nodes_dict = {a: b.shape[0] for a, b in nodes_all_idx.items()}
# Do Test-Val-Train split by patients and put all images for a patient into the chosen partition:
if isinstance(test_split, str) and test_split == "one sample":
h_test_dict = {i: 1 for i, b in h_nodes_dict.items()}
else:
h_test_dict = {i: round(b * test_split) for i, b in h_nodes_dict.items()}
self.nodes_idx_test = {
k: np.random.choice(a=nodes_all_idx[k], size=h_test_dict[k], replace=False) for k in self.h_0.keys()
}
nodes_idx_test_shapes = {a: b.shape[0] for a, b in self.nodes_idx_test.items()}
test_nodes = sum(nodes_idx_test_shapes.values())
nodes_idx_train_eval = {
i: np.array([x for x in nodes_all_idx[i] if x not in self.nodes_idx_test[i]]) for i in self.h_0.keys()
}
self.nodes_idx_eval = {
i: np.random.choice(a=nodes_idx_train_eval[i], size=round(len(nodes_idx_train_eval[i]) * validation_split))
for i in self.h_0.keys()
}
nodes_idx_eval_shapes = {a: b.shape[0] for a, b in self.nodes_idx_eval.items()}
eval_nodes = sum(nodes_idx_eval_shapes.values())
self.nodes_idx_train = {
i: np.array([x for x in nodes_idx_train_eval[i] if x not in self.nodes_idx_eval[i]])
for i in self.h_0.keys()
}
nodes_idx_train_shapes = {a: b.shape[0] for a, b in self.nodes_idx_train.items()}
train_nodes = sum(nodes_idx_train_shapes.values())
print(
"\nExcluded %i cells with the following unannotated cell type: [%s] \n"
"\nWhole dataset: %i cells out of %i images from %i patients."
% (
n_undefined_nodes,
self.undefined_node_types,
all_nodes,
len(list(self.complete_img_keys)),
len(self.patient_ids_unique),
)
)
print(
"Test dataset: %i cells out of %i images from %i patients."
% (
test_nodes,
len(self.img_keys_test),
len(self.patient_ids_unique),
)
)
print(
"Training dataset: %i cells out of %i images from %i patients."
% (
train_nodes,
len(self.img_keys_train),
len(self.patient_ids_unique),
)
)
print(
"Validation dataset: %i cells out of %i images from %i patients. \n"
% (
eval_nodes,
len(self.img_keys_eval),
len(self.patient_ids_unique),
)
)
# Check that none of the train, eval partitions are empty
if not eval_nodes:
raise ValueError("The evaluation dataset is empty.")
if not train_nodes:
raise ValueError("The train dataset is empty.")
[docs] def split_data_target_cell(self, target_cell: str, test_split: float, validation_split: float, seed: int = 1):
"""Split nodes randomly into partitions.
Parameters
----------
target_cell : str
Target cell type.
test_split : float
Fraction of total nodes to be in test set.
validation_split : float
Fraction of train-eval nodes to be in validation split.
seed : int
Seed for random selection of observations.
Raises
------
ValueError
If evaluation or test dataset are empty.
"""
print(
"Using split method: target cell. \n Train-test-validation split is based on total number of nodes "
"per patients over all images."
)
np.random.seed(seed)
h_nodes_dict = {a: b.shape[0] for a, b in self.h_0.items()}
all_nodes = sum(h_nodes_dict.values())
nodes_all_idx = {a: np.arange(0, b) for a, b in h_nodes_dict.items()}
target_cell_id = list(self.node_type_names.values()).index(target_cell)
# Assign images to partitions:
self.img_keys_train = list(self.complete_img_keys)
self.img_keys_eval = self.img_keys_train.copy()
self.img_keys_test = self.img_keys_train.copy()
n_undefined_nodes, nodes_all_idx = self._remove_unidentified_nodes(node_idx=nodes_all_idx)
# Dictionary of all nodes within a target cell type
nodes_all_idx = {k: np.where(self.node_types[k][:, target_cell_id] == 1)[0] for k in self.img_keys_train}
# updating h_nodes_dict to only include the number of identified cells of specific target cell
h_nodes_dict = {a: b.shape[0] for a, b in nodes_all_idx.items()}
nodes_all_idx_shapes = {a: b.shape[0] for a, b in nodes_all_idx.items()}
target_cell_nodes = sum(nodes_all_idx_shapes.values())
# Do Test-Val-Train split by patients and put all images for a patient into the chosen partition:
if isinstance(test_split, str) and test_split == "one sample":
h_test_dict = {i: 1 for i, b in h_nodes_dict.items()}
else:
h_test_dict = {i: round(b * test_split) for i, b in h_nodes_dict.items()}
# Assign nodes to partitions:
# Test partition:
self.nodes_idx_test = {
k: np.random.choice(a=nodes_all_idx[k], size=h_test_dict[k], replace=False) for k in self.h_0.keys()
}
nodes_idx_test_shapes = {a: b.shape[0] for a, b in self.nodes_idx_test.items()}
test_nodes = sum(nodes_idx_test_shapes.values())
# Define train-eval partition:
nodes_idx_train_eval = {
i: np.array([x for x in nodes_all_idx[i] if x not in self.nodes_idx_test[i]]) for i in self.h_0.keys()
}
# Randomly partition train-eval into train and eval:
self.nodes_idx_eval = {
i: np.random.choice(a=nodes_idx_train_eval[i], size=round(len(nodes_idx_train_eval[i]) * validation_split))
for i in self.h_0.keys()
}
nodes_idx_eval_shapes = {a: b.shape[0] for a, b in self.nodes_idx_eval.items()}
eval_nodes = sum(nodes_idx_eval_shapes.values())
# Assign all nodes in train-eval that are not assigned to eval to train:
self.nodes_idx_train = {
i: np.array([x for x in nodes_idx_train_eval[i] if x not in self.nodes_idx_eval[i]])
for i in self.h_0.keys()
}
nodes_idx_train_shapes = {a: b.shape[0] for a, b in self.nodes_idx_train.items()}
train_nodes = sum(nodes_idx_train_shapes.values())
print(
"\nExcluded %i cells with the following unannotated cell type: [%s] \n"
"\nWhole dataset: %i cells out of %i images from %i patients."
% (
n_undefined_nodes,
self.undefined_node_types,
all_nodes,
len(list(self.complete_img_keys)),
len(self.patient_ids_unique),
)
)
print(
"\nCell type used for training %s: %i cells out of %i images from %i patients. "
% (
target_cell,
target_cell_nodes,
len(list(self.complete_img_keys)),
len(self.patient_ids_unique),
)
)
print(
"Test dataset: %i cells out of %i images from %i patients. "
% (
test_nodes,
len(self.img_keys_test),
len(self.patient_ids_unique),
)
)
print(
"Training dataset: %i cells out of %i images from %i patients."
% (
train_nodes,
len(self.img_keys_train),
len(self.patient_ids_unique),
)
)
print(
"Validation dataset: %i cells out of %i images from %i patients.\n"
% (
eval_nodes,
len(self.img_keys_eval),
len(self.patient_ids_unique),
)
)
# Check that none of the train, eval partitions are empty
if not eval_nodes:
raise ValueError("The evaluation dataset is empty.")
if not train_nodes:
raise ValueError("The train dataset is empty.")
[docs] def train(
self,
epochs: int = 1000,
epochs_warmup: int = 0,
max_steps_per_epoch: Optional[int] = 20,
batch_size: int = 16,
validation_batch_size: int = 16,
max_validation_steps: Optional[int] = 10,
shuffle_buffer_size: Optional[int] = int(1e4),
patience: int = 20,
lr_schedule_min_lr: float = 1e-5,
lr_schedule_factor: float = 0.2,
lr_schedule_patience: int = 5,
initial_epoch: int = 0,
monitor_partition: str = "val",
monitor_metric: str = "loss",
log_dir: Optional[str] = None,
callbacks: Optional[list] = None,
early_stopping: bool = True,
reduce_lr_plateau: bool = True,
pretrain_decoder: bool = False,
decoder_epochs: int = 1000,
decoder_patience: int = 20,
decoder_callbacks: Optional[list] = None,
aggressive: bool = False,
aggressive_enc_patience: int = 10,
aggressive_epochs: int = 5,
seed: int = 1234,
**kwargs,
):
"""Train model.
Use validation loss and maximum number of epochs as termination criteria.
Parameters
----------
epochs : int
Integer number of times to iterate over the training data arrays. If unspecified, it will default to 1000.
epochs_warmup : int
Integer number of times to iterate over the training data arrays in warm up (without early stopping). If
unspecified, it will default to 0.
max_steps_per_epoch : int, optional
Maximal steps per epoch. If unspecified, it will default to 20.
batch_size : int
Number of samples per gradient update. If unspecified, it will default to 16.
validation_batch_size : int
Number of samples in validation. If unspecified, it will default to 16.
max_validation_steps : int
Maximal steps per validation. If unspecified, it will default to 10.
shuffle_buffer_size : int, optional
Shuffle buffer size. If unspecified, it will default to 1e4.
patience : int
Number of epochs with no improvement. If unspecified, it will default to 20.
lr_schedule_min_lr : float
Lower bound on the learning rate. If unspecified, it will default to 1e-5.
lr_schedule_factor : float
Factor by which the learning rate will be reduced. new_lr = lr * factor. If unspecified, it will default
to 0.2.
lr_schedule_patience : int
Number of epochs with no improvement after which learning rate will be reduced. If unspecified, it will
default to 5.
initial_epoch : int
Epoch at which to start training (useful for resuming a previous training run). If unspecified, it will
default to 0.
monitor_partition : str
Monitor partition.
monitor_metric : str
Monitor metric.
log_dir : str, optional
Logging directory.
callbacks : list, optional
List of callbacks to be called during training.
early_stopping : bool
Whether to activate early stopping.
reduce_lr_plateau : bool
Whether to reduce learning rate on plateau.
pretrain_decoder : bool
Whether to pretrain the decoder model.
decoder_epochs : int
Integer number of times to iterate over the training data arrays in decoder pretraining. If unspecified, it
will default to 1000.
decoder_patience : int
Number of epochs with no improvement in decoder pretraining. If unspecified, it will default to 20.
decoder_callbacks : list, optional
List of callbacks to be called during decoder pretraining.
aggressive : bool
Whether to train aggressive.
aggressive_enc_patience : int
Number of epochs with no improvement in aggressive training. If unspecified, it will default to 10.
aggressive_epochs : int
Integer number of times to iterate over the training data arrays in aggressive training. If unspecified, it
will default to 5.
seed : int
Random seed for reproduability.
kwargs
Arbitrary keyword arguments.
"""
# Save training settings to allow model restoring.
self.train_hyperparam = {
"epochs": epochs,
"epochs_warmup": epochs_warmup,
"max_steps_per_epoch": max_steps_per_epoch,
"batch_size": batch_size,
"validation_batch_size": validation_batch_size,
"max_validation_steps": max_validation_steps,
"shuffle_buffer_size": shuffle_buffer_size,
"patience": patience,
"lr_schedule_min_lr": lr_schedule_min_lr,
"lr_schedule_factor": lr_schedule_factor,
"lr_schedule_patience": lr_schedule_patience,
"log_dir": log_dir,
"pretrain_decoder": pretrain_decoder,
"decoder_epochs": decoder_epochs,
"decoder_patience": decoder_patience,
"aggressive": aggressive,
"aggressive_enc_patience": aggressive_enc_patience,
"aggressive_epochs": aggressive_epochs,
}
self.train_dataset = self._get_dataset(
image_keys=self.img_keys_train,
nodes_idx=self.nodes_idx_train,
batch_size=batch_size,
shuffle_buffer_size=shuffle_buffer_size,
train=True,
seed=seed,
reinit_n_eval=None,
)
self.eval_dataset = self._get_dataset(
image_keys=self.img_keys_eval,
nodes_idx=self.nodes_idx_eval,
batch_size=validation_batch_size,
shuffle_buffer_size=shuffle_buffer_size,
train=True,
seed=seed,
reinit_n_eval=None,
)
self.steps_per_epoch = min(max(len(self.img_keys_train) // batch_size, 1), max_steps_per_epoch)
self.validation_steps = min(max(len(self.img_keys_eval) // validation_batch_size, 1), max_validation_steps)
if pretrain_decoder:
self.pretrain_decoder(
decoder_epochs=decoder_epochs,
patience=decoder_patience,
lr_schedule_min_lr=lr_schedule_min_lr,
lr_schedule_factor=lr_schedule_factor,
lr_schedule_patience=lr_schedule_patience,
initial_epoch=initial_epoch,
monitor_partition=monitor_partition,
monitor_metric=monitor_metric,
log_dir=log_dir,
callbacks=decoder_callbacks,
early_stopping=early_stopping,
reduce_lr_plateau=reduce_lr_plateau,
**kwargs,
)
if aggressive:
self.train_aggressive(aggressive_enc_patience=aggressive_enc_patience, aggressive_epochs=aggressive_epochs)
if epochs_warmup > 0:
self.train_normal(
epochs=epochs_warmup,
patience=patience,
lr_schedule_min_lr=lr_schedule_min_lr,
lr_schedule_factor=lr_schedule_factor,
lr_schedule_patience=int(10000), # dont reduce
initial_epoch=initial_epoch,
monitor_partition=monitor_partition,
monitor_metric=monitor_metric,
log_dir=log_dir,
callbacks=callbacks,
early_stopping=False,
reduce_lr_plateau=reduce_lr_plateau,
**kwargs,
)
initial_epoch += epochs_warmup
self.train_normal(
epochs=epochs,
patience=patience,
lr_schedule_min_lr=lr_schedule_min_lr,
lr_schedule_factor=lr_schedule_factor,
lr_schedule_patience=lr_schedule_patience,
initial_epoch=initial_epoch,
monitor_partition=monitor_partition,
monitor_metric=monitor_metric,
log_dir=log_dir,
callbacks=callbacks,
early_stopping=early_stopping,
reduce_lr_plateau=reduce_lr_plateau,
**kwargs,
)
[docs] def train_normal(
self,
epochs: int = 1000,
patience: int = 20,
lr_schedule_min_lr: float = 1e-5,
lr_schedule_factor: float = 0.2,
lr_schedule_patience: int = 5,
initial_epoch: int = 0,
monitor_partition: str = "val",
monitor_metric: str = "loss",
log_dir: Optional[str] = None,
callbacks: Optional[list] = None,
early_stopping: bool = True,
reduce_lr_plateau: bool = True,
**kwargs,
):
"""Train model normal.
Use validation loss and maximum number of epochs as termination criteria.
Parameters
----------
epochs : int
Integer number of times to iterate over the training data arrays. If unspecified, it will default to 1000.
patience : int
Number of epochs with no improvement. If unspecified, it will default to 20.
lr_schedule_min_lr : float
Lower bound on the learning rate. If unspecified, it will default to 1e-5.
lr_schedule_factor : float
Factor by which the learning rate will be reduced. new_lr = lr * factor. If unspecified, it will default
to 0.2.
lr_schedule_patience : int
Number of epochs with no improvement after which learning rate will be reduced. If unspecified, it will
default to 5.
initial_epoch : int
Epoch at which to start training (useful for resuming a previous training run). If unspecified, it will
default to 0.
monitor_partition : str
Monitor partition.
monitor_metric : str
Monitor metric.
log_dir : str, optional
Logging directory.
callbacks : list, optional
List of callbacks to be called during training.
early_stopping : bool
Whether to activate early stopping.
reduce_lr_plateau : bool
Whether to reduce learning rate on plateau.
kwargs
Arbitrary keyword arguments.
"""
# Set callbacks.
cbs = []
if reduce_lr_plateau:
cbs.append(
tf.keras.callbacks.ReduceLROnPlateau(
monitor=monitor_partition + "_" + monitor_metric,
factor=lr_schedule_factor,
patience=lr_schedule_patience,
min_lr=lr_schedule_min_lr,
)
)
if early_stopping:
cbs.append(
tf.keras.callbacks.EarlyStopping(
monitor=monitor_partition + "_" + monitor_metric, patience=patience, restore_best_weights=True
)
)
if log_dir is not None:
cbs.append(
tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1,
write_graph=False,
write_grads=False,
write_images=False,
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
embeddings_data=None,
update_freq="epoch",
)
)
if callbacks is not None:
# callbacks needs to be a list
cbs += callbacks
history = self.model.training_model.fit(
x=self.train_dataset,
epochs=epochs,
initial_epoch=initial_epoch,
steps_per_epoch=self.steps_per_epoch,
callbacks=cbs,
validation_data=self.eval_dataset,
validation_steps=self.validation_steps,
verbose=2,
**kwargs,
).history
for k, v in history.items(): # append to history if train() has been called before.
if k in self.history.keys():
self.history[k].extend(v)
else:
self.history[k] = v
[docs] def pretrain_decoder(
self,
decoder_epochs: int = 1000,
patience: int = 20,
lr_schedule_min_lr: float = 1e-5,
lr_schedule_factor: float = 0.2,
lr_schedule_patience: int = 5,
initial_epoch: int = 0,
monitor_partition: str = "val",
monitor_metric: str = "loss",
log_dir: Optional[str] = None,
callbacks: Optional[list] = None,
early_stopping: bool = True,
reduce_lr_plateau: bool = True,
**kwargs,
):
"""Pre-train decoder model.
Use validation loss and maximum number of epochs as termination criteria.
Parameters
----------
patience : int
Number of epochs with no improvement. If unspecified, it will default to 20.
lr_schedule_min_lr : float
Lower bound on the learning rate. If unspecified, it will default to 1e-5.
lr_schedule_factor : float
Factor by which the learning rate will be reduced. new_lr = lr * factor. If unspecified, it will default
to 0.2.
lr_schedule_patience : int
Number of epochs with no improvement after which learning rate will be reduced. If unspecified, it will
default to 5.
initial_epoch : int
Epoch at which to start training (useful for resuming a previous training run). If unspecified, it will
default to 0.
monitor_partition : str
Monitor partition.
monitor_metric : str
Monitor metric.
log_dir : str, optional
Logging directory.
callbacks : list, optional
List of callbacks to be called during training.
early_stopping : bool
Whether to activate early stopping.
reduce_lr_plateau : bool
Whether to reduce learning rate on plateau.
decoder_epochs : int
Integer number of times to iterate over the training data arrays in decoder pretraining. If unspecified, it
will default to 1000.
kwargs
Arbitrary keyword arguments.
"""
# Set callbacks.
cbs = []
if reduce_lr_plateau:
cbs.append(
tf.keras.callbacks.ReduceLROnPlateau(
monitor=monitor_partition + "_" + monitor_metric,
factor=lr_schedule_factor,
patience=lr_schedule_patience,
min_lr=lr_schedule_min_lr,
)
)
if early_stopping:
cbs.append(
tf.keras.callbacks.EarlyStopping(
monitor=monitor_partition + "_" + monitor_metric, patience=patience, restore_best_weights=True
)
)
if log_dir is not None:
cbs.append(
tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=1,
write_graph=False,
write_grads=False,
write_images=False,
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
embeddings_data=None,
update_freq="epoch",
)
)
if callbacks is not None:
# callbacks needs to be a list
cbs += callbacks
history = self.model.decoder_sampling.fit(
x=self.train_dataset,
epochs=decoder_epochs,
initial_epoch=initial_epoch,
steps_per_epoch=self.steps_per_epoch,
callbacks=cbs,
validation_data=self.eval_dataset,
validation_steps=self.validation_steps,
verbose=2,
**kwargs,
).history
for k, v in history.items(): # append to history if train() has been called before.
if k in self.history.keys():
self.history[k].extend(v)
else:
self.history[k] = v
# Transfer weights:
layer_names_training_model = [x.name for x in self.model.training_model.layers]
layer_names_decoder_model = [x.name for x in self.model.decoder_sampling.layers]
layers_updated = []
for x in layer_names_decoder_model:
w = self.model.decoder_sampling.get_layer(name=x).get_weights()
if x in layer_names_training_model:
# Only update layers with parameters:
if len(w) > 0:
self.model.training_model.get_layer(x).set_weights(w)
layers_updated.append(x)
elif "Output_sampling" in x:
# Find output layer *Output_decoder matched to *Output_sampling:
x_out = [y for y in layer_names_training_model if "Output_decoder" in y][0]
self.model.training_model.get_layer(x_out).set_weights(w)
layers_updated.append(x_out)
print(f"updated layers: {layers_updated}")
[docs] def train_aggressive(
self,
aggressive_enc_patience: int = 10,
aggressive_epochs: int = 5,
):
"""Train model aggressive.
Parameters
----------
aggressive_enc_patience : int
Number of epochs with no improvement in aggressive training. If unspecified, it will default to 10.
aggressive_epochs : int
Integer number of times to iterate over the training data arrays in aggressive training. If unspecified, it
will default to 5.
"""
# @tf.function
def train_iter(
x_batch_aggressive,
y_batch_aggressive,
train_dec,
train_enc,
):
with tf.GradientTape() as g:
output_decoder_concat, latent_space = self.model.training_model(x_batch_aggressive)
losses_aggressive = {
"reconstruction_loss": self.loss[0](y_batch_aggressive[0], output_decoder_concat),
"bottleneck_loss": self.loss[1](y_batch_aggressive[1], latent_space),
}
losses_aggressive["loss"] = (
losses_aggressive["reconstruction_loss"] + losses_aggressive["bottleneck_loss"]
)
if train_enc:
grad_enc = g.gradient(
target=losses_aggressive["loss"], sources=self.model.encoder_model.trainable_variables
)
self.optimizer.apply_gradients(zip(grad_enc, self.model.encoder_model.trainable_variables))
if train_dec:
grad_dec = g.gradient(
target=losses_aggressive["loss"], sources=self.model.decoder_model.trainable_variables
)
self.optimizer.apply_gradients(zip(grad_dec, self.model.decoder_model.trainable_variables))
metrics_values_output = {
"reconstruction_" + metric.__name__: metric(y_batch_aggressive[0], output_decoder_concat)
for metric in self.metrics[0]
}
metrics_values_latent = {
"bottleneck_" + metric.__name__: metric(y_batch_aggressive[1], latent_space)
for metric in self.metrics[1]
}
losses_aggressive.update(metrics_values_output)
losses_aggressive.update(metrics_values_latent)
# Add non-scaled ELBO to model as metric (ie no annealing or beta-VAE scaling):
log2pi = tf.math.log(2.0 * np.pi)
z, z_mean, z_log_var = tf.split(latent_space, num_or_size_splits=3, axis=1)
logqz_x = -0.5 * tf.reduce_mean(tf.square(z - z_mean) * tf.exp(-z_log_var) + z_log_var + log2pi)
logpz = -0.5 * tf.reduce_mean(tf.square(z) + log2pi)
d_kl = logqz_x - logpz
loc, scale = tf.split(output_decoder_concat, num_or_size_splits=2, axis=2)
input_x = x_batch_aggressive[0]
if self.output_layer == "gaussian" or self.output_layer == "gaussian_const_disp":
neg_ll = tf.math.log(tf.sqrt(2 * np.math.pi) * scale) + 0.5 * tf.math.square(
loc - input_x
) / tf.math.square(scale)
elif (
self.output_layer == "nb"
or self.output_layer == "nb_const_disp"
or self.output_layer == "nb_shared_disp"
):
eta_loc = tf.math.log(loc)
eta_scale = tf.math.log(scale)
log_r_plus_mu = tf.math.log(scale + loc)
ll = tf.math.lgamma(scale + input_x)
ll = ll - tf.math.lgamma(input_x + tf.ones_like(input_x))
ll = ll - tf.math.lgamma(scale)
ll = ll + tf.multiply(input_x, eta_loc - log_r_plus_mu) + tf.multiply(scale, eta_scale - log_r_plus_mu)
neg_ll = -tf.clip_by_value(ll, -300, 300, "log_probs")
else:
neg_ll = None
neg_ll = tf.reduce_mean(tf.reduce_sum(neg_ll, axis=-1))
losses_aggressive["elbo"] = neg_ll + d_kl
return losses_aggressive
aggressive = True
history = {}
ep = 0
while aggressive:
ep += 1
print("Epoch (aggressive) {}/{} - ".format(ep, aggressive_epochs), end="")
start = time.time()
no_improvement = 0
best_result = None
# inner loop training only encoder until no further improvement in ELBO val loss
enc_updates = 0
count = 0
while no_improvement < aggressive_enc_patience:
enc_updates += 1
for step, (x_batch, y_batch) in enumerate(self.train_dataset):
count = step
if step >= self.steps_per_epoch:
break
_ = train_iter(
x_batch_aggressive=x_batch,
y_batch_aggressive=y_batch,
train_enc=True,
train_dec=False,
)
elbo_enc_epoch_eval = 0
for step, (x_batch, y_batch) in enumerate(self.eval_dataset):
count = step
if step >= self.validation_steps:
break
losses = train_iter(
x_batch_aggressive=x_batch,
y_batch_aggressive=y_batch,
train_enc=False,
train_dec=False,
)
elbo_enc_epoch_eval += losses["elbo"]
elbo_enc_epoch_eval /= count
if best_result is None:
best_result = elbo_enc_epoch_eval
elif elbo_enc_epoch_eval < best_result:
best_result = elbo_enc_epoch_eval
no_improvement = 0
else:
no_improvement += 1
print("Performed %d encoder updates" % enc_updates)
# one step decoder training
hist = {}
for step, (x_batch, y_batch) in enumerate(self.train_dataset):
count = step
if step >= self.steps_per_epoch:
break
losses = train_iter(
x_batch_aggressive=x_batch,
y_batch_aggressive=y_batch,
train_enc=False,
train_dec=True,
)
for k, v in losses.items():
if k in hist.keys():
hist[k] += np.mean(v)
else:
hist[k] = np.mean(v)
hist = {key: value / count for key, value in hist.items()}
for k, v in hist.items():
if k in history.keys():
history[k].append(v)
else:
history[k] = [v]
hist_eval = {}
for step, (x_batch, y_batch) in enumerate(self.eval_dataset):
count = step
if step >= self.validation_steps:
break
losses = train_iter(
x_batch_aggressive=x_batch,
y_batch_aggressive=y_batch,
train_enc=False,
train_dec=False,
)
for k, v in losses.items():
if k in hist_eval.keys():
hist_eval[k] += np.mean(v)
else:
hist_eval[k] = np.mean(v)
hist_eval = {"val_" + key: value / count for key, value in hist_eval.items()}
for k, v in hist_eval.items():
if k in history.keys():
history[k].append(v)
else:
history[k] = [v]
if "lr" in history.keys():
history["lr"].append(self.optimizer.lr.numpy())
else:
history["lr"] = [self.optimizer.lr.numpy()]
if len(history["loss"]) >= aggressive_epochs:
aggressive = False
print("%d/%d - %ds" % (self.steps_per_epoch, self.steps_per_epoch, time.time() - start), end="")
for key, loss in history.items():
print(" - %s: %f" % (key, loss[-1]), end="")
print()
for k, v in history.items(): # append to history if train() has been called before.
if k in self.history.keys():
self.history[k].extend(v)
else:
self.history[k] = v
def _get_dataset_test(self, batch_size: int = 1):
"""Get test dataset.
Parameters
----------
batch_size : int
Number of samples. If unspecified, it will default to 1.
Returns
-------
A tensorflow dataset.
"""
if self.img_keys_test is not None and len(self.img_keys_test) != 0:
image_keys = self.img_keys_test
else:
warnings.warn("Image keys for test set empty. Evaluating on all images in whole dataset!")
image_keys = list(self.complete_img_keys)
if self.nodes_idx_test is not None:
nodes_idx = self.nodes_idx_test
else:
warnings.warn("Node idx for test set empty. Evaluating on all nodes in whole dataset!")
nodes_idx = "all"
return self._get_dataset(
image_keys=image_keys,
nodes_idx=nodes_idx,
batch_size=batch_size,
shuffle_buffer_size=1,
train=False,
seed=None,
reinit_n_eval=None,
)
[docs] def predict(self, batch_size: int = 1) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""Return observed labels and full predictions (including scale model) grouped exactly as in nodes_idx_test.
Parameters
----------
batch_size : int
Number of samples. If unspecified, it will default to 1.
Returns
-------
predict
"""
ds = self._get_dataset_test(batch_size=batch_size)
return self.model.training_model.predict(ds)
[docs] def evaluate_any(self, img_keys, node_idx, batch_size: int = 1):
"""Evaluate model on any given data set.
Parameters
----------
img_keys
Image keys.
node_idx
Nodes indices.
batch_size : int
Number of samples. If unspecified, it will default to 1.
Returns
-------
eval_dict
"""
ds = self._get_dataset(
image_keys=img_keys,
nodes_idx=node_idx,
batch_size=batch_size,
shuffle_buffer_size=1,
train=False,
seed=None,
reinit_n_eval=None,
)
results = self.model.training_model.evaluate(ds, verbose=False)
eval_dict = dict(zip(self.model.training_model.metrics_names, results))
return eval_dict
[docs] def evaluate_per_node_type(self, batch_size: int = 1):
"""Evaluate model for each node type seperately.
Parameters
----------
batch_size : int
Number of samples. If unspecified, it will default to 1.
Returns
-------
split_per_node_type, evaluation_per_node_type
"""
if self.simulation:
split_per_node_type = None
evaluation_per_node_type = None
else:
evaluation_per_node_type = {}
split_per_node_type = {}
node_types = list(self.node_type_names.keys())
for nt in node_types:
img_keys = list(self.complete_img_keys)
nodes_idx = {k: np.where(self.node_types[k][:, node_types.index(nt)] == 1)[0] for k in img_keys}
split_per_node_type.update({nt: {"img_keys": img_keys, "nodes_idx": nodes_idx}})
test = {k: len(np.where(self.node_types[k][:, node_types.index(nt)] == 1)[0]) for k in img_keys}
print("Evaluation for %s with %i cells" % (nt, sum(test.values())))
ds = self._get_dataset(
image_keys=img_keys,
nodes_idx=nodes_idx,
batch_size=batch_size,
shuffle_buffer_size=1,
train=False,
seed=None,
reinit_n_eval=None,
)
results = self.model.training_model.evaluate(ds, verbose=False)
eval_dict = dict(zip(self.model.training_model.metrics_names, results))
print(eval_dict)
evaluation_per_node_type.update({nt: eval_dict})
return split_per_node_type, evaluation_per_node_type
[docs]class EstimatorGraph(Estimator):
"""EstimatorGraph class for spatial models."""
[docs] def init_model(self, **kwargs):
"""Initialize EstimatorGraph.
Parameters
----------
kwargs
Arbitrary keyword arguments.
"""
pass
def _get_output_signature(self, resampled: bool = False):
"""Get output signatures.
Parameters
----------
resampled : bool
Whether dataset is resampled or not.
Returns
-------
output_signature
"""
h_1 = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_features_1), dtype=tf.float32
) # input node features
sf = tf.TensorSpec(shape=(self.n_eval_nodes_per_graph, 1), dtype=tf.float32) # input node size factors
h_0 = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_features_0), dtype=tf.float32
) # input node features conditional
h_0_full = tf.TensorSpec(
shape=(self.max_nodes, self.n_features_0), dtype=tf.float32
) # input node features conditional
a = tf.SparseTensorSpec(shape=None, dtype=tf.float32) # adjacency matrix
a_full = tf.SparseTensorSpec(shape=None, dtype=tf.float32) # adjacency matrix
node_covar = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_node_covariates), dtype=tf.float32
) # node-level covariates
domain = tf.TensorSpec(shape=(self.n_domains,), dtype=tf.int32) # domain
reconstruction = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_features_1), dtype=tf.float32
) # node features to reconstruct
kl_dummy = tf.TensorSpec(shape=(self.n_eval_nodes_per_graph,), dtype=tf.float32) # dummy for kl loss
if self.vi_model:
if resampled:
output_signature = (
(h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain),
(reconstruction, kl_dummy),
(h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain),
(reconstruction, kl_dummy),
)
else:
output_signature = ((h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain), (reconstruction, kl_dummy))
else:
if resampled:
output_signature = (
(h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain),
reconstruction,
(h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain),
reconstruction,
)
else:
output_signature = ((h_1, sf, h_0, h_0_full, a, a_full, node_covar, domain), reconstruction)
return output_signature
def _get_dataset(
self,
image_keys: List[str],
nodes_idx: Dict[str, np.ndarray],
batch_size: int,
shuffle_buffer_size: Optional[int],
train: bool = True,
seed: Optional[int] = None,
prefetch: int = 100,
reinit_n_eval: Optional[int] = None,
):
"""Prepare a dataset.
Parameters
----------
image_keys : np.array
Image keys in partition.
nodes_idx : dict, str
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
shuffle_buffer_size : int, optional
Shuffle buffer size.
train : bool
Whether dataset is used for training or not (influences shuffling of nodes).
seed : int, optional
Random seed.
prefetch: int
Prefetch of dataset.
reinit_n_eval : int, optional
Used if model is reinitialized to different number of nodes per graph.
Returns
-------
A tensorflow dataset.
"""
np.random.seed(seed)
if reinit_n_eval is not None and reinit_n_eval != self.n_eval_nodes_per_graph:
print(
"ATTENTION: specifying reinit_n_eval will change class argument n_eval_nodes_per_graph "
"from %i to %i" % (self.n_eval_nodes_per_graph, reinit_n_eval)
)
self.n_eval_nodes_per_graph = reinit_n_eval
def generator():
for key in image_keys:
if nodes_idx[key].size == 0: # needed for images where no nodes are selected
continue
idx_nodes = np.arange(0, self.a[key].shape[0])
if train:
index_list = [
np.asarray(
np.random.choice(
a=nodes_idx[key],
size=self.n_eval_nodes_per_graph,
replace=True,
),
dtype=np.int32,
)
]
else:
# dropping
index_list = [
np.asarray(
nodes_idx[key][self.n_eval_nodes_per_graph * i : self.n_eval_nodes_per_graph * (i + 1)],
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
for indices in index_list:
h_0 = self.h_0[key][idx_nodes]
diff = self.max_nodes - h_0.shape[0]
zeros = np.zeros((diff, h_0.shape[1]), dtype="float32")
h_0_full = np.asarray(np.concatenate((h_0, zeros), axis=0), dtype="float32")
h_0 = h_0_full[indices]
h_1 = self.h_1[key][idx_nodes]
diff = self.max_nodes - h_1.shape[0]
zeros = np.zeros((diff, h_1.shape[1]), dtype="float32")
h_1 = np.asarray(np.concatenate((h_1, zeros), axis=0), dtype="float32")
h_1 = h_1[indices]
if self.log_transform:
h_1 = np.log(h_1 + 1.0)
# indexing adjacency matrix that yield only selected cells (final graph layer)
a = self.a[key][idx_nodes, :][:, idx_nodes]
a = a[indices, :]
coo = a.tocoo()
a_ind = np.asarray(np.mat([coo.row, coo.col]).transpose(), dtype="int64")
a_val = np.asarray(coo.data, dtype="float32")
a_shape = np.asarray((self.n_eval_nodes_per_graph, self.max_nodes), dtype="int64")
a = tf.SparseTensor(indices=a_ind, values=a_val, dense_shape=a_shape)
# propagating adjacency matrix that yield all cells (before final graph layer)
if self.cond_depth > 1:
a_full = self.a[key][idx_nodes, :][:, idx_nodes]
a_full = a_full.tocoo()
afull_ind = np.asarray(np.mat([a_full.row, a_full.col]).transpose(), dtype="int64")
afull_val = np.asarray(a_full.data, dtype="float32")
else:
afull_ind = np.asarray(np.mat([np.zeros((0,)), np.zeros((0,))]).transpose(), dtype="int64")
afull_val = np.asarray(np.zeros((0,)), dtype="float32")
afull_shape = np.asarray((self.max_nodes, self.max_nodes), dtype="int64")
a_full = tf.SparseTensor(indices=afull_ind, values=afull_val, dense_shape=afull_shape)
node_covar = self.node_covar[key][idx_nodes]
diff = self.max_nodes - node_covar.shape[0]
zeros = np.zeros((diff, node_covar.shape[1]))
node_covar = np.asarray(np.concatenate([node_covar, zeros], axis=0), dtype="float32")
node_covar = node_covar[indices, :]
sf = np.expand_dims(self.size_factors[key][idx_nodes], axis=1)
diff = self.max_nodes - sf.shape[0]
zeros = np.zeros((diff, sf.shape[1]))
sf = np.asarray(np.concatenate([sf, zeros], axis=0), dtype="float32")
sf = sf[indices, :]
g = np.zeros((self.n_domains,), dtype="int32")
g[self.domains[key]] = 1
if self.vi_model:
kl_dummy = np.zeros((self.n_eval_nodes_per_graph,), dtype="float32")
yield (h_1, sf, h_0, h_0_full, a, a_full, node_covar, g), (h_1, kl_dummy)
else:
yield (h_1, sf, h_0, h_0_full, a, a_full, node_covar, g), h_1
output_signature = self._get_output_signature(resampled=False)
dataset = tf.data.Dataset.from_generator(generator=generator, output_signature=output_signature)
if train:
if shuffle_buffer_size is not None:
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=None, reshuffle_each_iteration=True)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(prefetch)
return dataset
def _get_resampled_dataset(
self,
image_keys: np.ndarray,
nodes_idx: dict,
batch_size: int,
seed: Optional[int] = None,
prefetch: int = 100,
reinit_n_eval: Optional[int] = None,
):
"""Evaluate model based on resampled dataset for posterior resampling.
node_1 + domain_1 -> encoder -> z_1 + domain_2 -> decoder -> reconstruction_2.
Parameters
----------
image_keys: np.array
Image keys in partition.
nodes_idx : dict
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
seed : int, optional
Seed.
prefetch : int
Prefetch.
reinit_n_eval : int, optional
Used if model is reinitialized to different number of nodes per graph.
Returns
-------
A tensorflow dataset.
"""
np.random.seed(seed)
if reinit_n_eval is not None:
print(
"ATTENTION: specifying reinit_n_eval will change class argument n_eval_nodes_per_graph "
"from %i to %i" % (self.n_eval_nodes_per_graph, reinit_n_eval)
)
self.n_eval_nodes_per_graph = reinit_n_eval
def generator():
for key in image_keys:
if nodes_idx[key].size == 0: # needed for images where no nodes are selected
continue
idx_nodes = np.arange(0, self.a[key].shape[0])
index_list = [
np.asarray(
nodes_idx[key][self.n_eval_nodes_per_graph * i : self.n_eval_nodes_per_graph * (i + 1)],
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
resampled_index_list = [
np.asarray(
np.random.choice(
a=nodes_idx[key],
size=self.n_eval_nodes_per_graph,
replace=True,
),
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
for i, indices in enumerate(index_list):
re_indices = resampled_index_list[i]
h_0 = self.h_0[key][idx_nodes]
diff = self.max_nodes - h_0.shape[0]
zeros = np.zeros((diff, h_0.shape[1]), dtype="float32")
h_0_full = np.asarray(np.concatenate((h_0, zeros), axis=0), dtype="float32")
re_h_0 = h_0_full[re_indices]
h_0 = h_0_full[indices]
h_1 = self.h_1[key][idx_nodes]
diff = self.max_nodes - h_1.shape[0]
zeros = np.zeros((diff, h_1.shape[1]), dtype="float32")
h_1 = np.asarray(np.concatenate((h_1, zeros), axis=0), dtype="float32")
re_h_1 = h_1[re_indices]
h_1 = h_1[indices]
if self.log_transform:
h_1 = np.log(h_1 + 1.0)
re_h_1 = np.log(re_h_1 + 1.0)
# indexing adjacency matrix that yield only selected cells (final graph layer)
a = self.a[key][idx_nodes, :][:, idx_nodes]
re_a = a[re_indices, :]
re_coo = re_a.tocoo()
re_a_ind = np.asarray(np.mat([re_coo.row, re_coo.col]).transpose(), dtype="int64")
re_a_val = np.asarray(re_coo.data, dtype="float32")
re_a_shape = np.asarray((self.n_eval_nodes_per_graph, self.max_nodes), dtype="int64")
re_a = tf.SparseTensor(indices=re_a_ind, values=re_a_val, dense_shape=re_a_shape)
a = a[indices, :]
coo = a.tocoo()
a_ind = np.asarray(np.mat([coo.row, coo.col]).transpose(), dtype="int64")
a_val = np.asarray(coo.data, dtype="float32")
a_shape = np.asarray((self.n_eval_nodes_per_graph, self.max_nodes), dtype="int64")
a = tf.SparseTensor(indices=a_ind, values=a_val, dense_shape=a_shape)
# propagating adjacency matrix that yield all cells (before final graph layer)
if self.model.args["cond_depth"] > 1:
a_full = self.a[key][idx_nodes, :][:, idx_nodes]
a_full = a_full.tocoo()
afull_ind = np.asarray(np.mat([a_full.row, a_full.col]).transpose(), dtype="int64")
afull_val = np.asarray(a_full.data, dtype="float32")
else:
afull_ind = np.asarray(np.mat([np.zeros((0,)), np.zeros((0,))]).transpose(), dtype="int64")
afull_val = np.asarray(np.zeros((0,)), dtype="float32")
afull_shape = np.asarray((self.max_nodes, self.max_nodes), dtype="int64")
a_full = tf.SparseTensor(indices=afull_ind, values=afull_val, dense_shape=afull_shape)
node_covar = self.node_covar[key][idx_nodes]
diff = self.max_nodes - node_covar.shape[0]
zeros = np.zeros((diff, node_covar.shape[1]))
node_covar = np.asarray(np.concatenate([node_covar, zeros], axis=0), dtype="float32")
re_node_covar = node_covar[re_indices, :]
node_covar = node_covar[indices, :]
sf = np.expand_dims(self.size_factors[key][idx_nodes], axis=1)
diff = self.max_nodes - sf.shape[0]
zeros = np.zeros((diff, sf.shape[1]))
sf = np.asarray(np.concatenate([sf, zeros], axis=0), dtype="float32")
re_sf = sf[re_indices, :]
sf = sf[indices, :]
g = np.zeros((self.n_domains,), dtype="int32")
g[self.domains[key]] = 1
if self.vi_model:
kl_dummy = np.zeros((self.n_eval_nodes_per_graph,), dtype="float32")
yield (h_1, sf, h_0, h_0_full, a, a_full, node_covar, g), (h_1, kl_dummy), (
re_h_1,
re_sf,
re_h_0,
h_0_full,
re_a,
a_full,
re_node_covar,
g,
), (re_h_1, kl_dummy)
else:
yield (h_1, sf, h_0, h_0_full, a, a_full, node_covar, g), h_1, (
re_h_1,
re_sf,
re_h_0,
h_0_full,
re_a,
a_full,
re_node_covar,
g,
), re_h_1
output_signature = self._get_output_signature(resampled=True)
dataset = tf.data.Dataset.from_generator(generator=generator, output_signature=output_signature)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(prefetch)
return dataset
[docs]class EstimatorNoGraph(Estimator):
"""EstimatorNoGraph class for baseline models."""
[docs] def init_model(self, **kwargs):
"""Initialize EstimatorNoGraph.
Parameters
----------
kwargs
Arbitrary keyword arguments.
"""
pass
def _get_output_signature(self, resampled: bool = False):
"""Get output signatures.
Parameters
----------
resampled : bool
Whether dataset is resampled or not.
Returns
-------
output_signature
"""
h_1 = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_features_1), dtype=tf.float32
) # input node features
sf = tf.TensorSpec(shape=(self.n_eval_nodes_per_graph, 1), dtype=tf.float32) # input node size factors
node_covar = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_node_covariates), dtype=tf.float32
) # node-level covariates
domain = tf.TensorSpec(shape=(self.n_domains,), dtype=tf.int32) # domain
reconstruction = tf.TensorSpec(
shape=(self.n_eval_nodes_per_graph, self.n_features_1), dtype=tf.float32
) # node features to reconstruct
kl_dummy = tf.TensorSpec(shape=(self.n_eval_nodes_per_graph,), dtype=tf.float32) # dummy for kl loss
if self.vi_model:
if resampled:
output_signature = (
(h_1, sf, node_covar, domain),
(reconstruction, kl_dummy),
(h_1, sf, node_covar, domain),
(reconstruction, kl_dummy), # shapes for resampled output
)
else:
output_signature = ((h_1, sf, node_covar, domain), (reconstruction, kl_dummy))
else:
if resampled:
output_signature = (
(h_1, sf, node_covar, domain),
reconstruction,
(h_1, sf, node_covar, domain),
reconstruction, # shapes for resampled output
)
else:
output_signature = ((h_1, sf, node_covar, domain), reconstruction)
return output_signature
def _get_dataset(
self,
image_keys: List[str],
nodes_idx: Dict[str, np.ndarray],
batch_size: int,
shuffle_buffer_size: Optional[int],
train: bool = True,
seed: Optional[int] = None,
prefetch: int = 100,
reinit_n_eval: Optional[int] = None,
):
"""Prepare a dataset.
Parameters
----------
image_keys : np.array
Image keys in partition.
nodes_idx : dict, str
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
shuffle_buffer_size : int, optional
Shuffle buffer size.
train : bool
Whether dataset is used for training or not (influences shuffling of nodes).
seed : int, optional
Random seed.
prefetch: int
Prefetch of dataset.
reinit_n_eval : int, optional
Used if model is reinitialized to different number of nodes per graph.
Returns
-------
A tensorflow dataset.
"""
np.random.seed(seed)
if reinit_n_eval is not None and reinit_n_eval != self.n_eval_nodes_per_graph:
print(
"ATTENTION: specifying reinit_n_eval will change class argument n_eval_nodes_per_graph "
"from %i to %i" % (self.n_eval_nodes_per_graph, reinit_n_eval)
)
self.n_eval_nodes_per_graph = reinit_n_eval
def generator():
for key in image_keys:
if nodes_idx[key].size == 0: # needed for images where no nodes are selected
continue
idx_nodes = np.arange(0, self.a[key].shape[0])
if train:
index_list = [
np.asarray(
np.random.choice(
a=nodes_idx[key],
size=self.n_eval_nodes_per_graph,
replace=True,
),
dtype=np.int32,
)
]
else:
# dropping
index_list = [
np.asarray(
nodes_idx[key][self.n_eval_nodes_per_graph * i : self.n_eval_nodes_per_graph * (i + 1)],
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
for indices in index_list:
h_1 = self.h_1[key][idx_nodes]
diff = self.max_nodes - h_1.shape[0]
zeros = np.zeros((diff, h_1.shape[1]))
h_1 = np.asarray(np.concatenate((h_1, zeros), axis=0), dtype="float32")
h_1 = h_1[indices]
if self.log_transform:
h_1 = np.log(h_1 + 1.0)
node_covar = self.node_covar[key][idx_nodes]
diff = self.max_nodes - node_covar.shape[0]
zeros = np.zeros((diff, node_covar.shape[1]))
node_covar = np.asarray(np.concatenate([node_covar, zeros], axis=0), dtype="float32")
node_covar = node_covar[indices]
sf = np.expand_dims(self.size_factors[key][idx_nodes], axis=1)
diff = self.max_nodes - sf.shape[0]
zeros = np.zeros((diff, sf.shape[1]))
sf = np.asarray(np.concatenate([sf, zeros], axis=0), dtype="float32")
sf = sf[indices, :]
g = np.zeros((self.n_domains,), dtype="int32")
g[self.domains[key]] = 1
if self.vi_model:
kl_dummy = np.zeros((self.n_eval_nodes_per_graph,), dtype="float32")
yield (h_1, sf, node_covar, g), (h_1, kl_dummy)
else:
yield (h_1, sf, node_covar, g), h_1
output_signature = self._get_output_signature(resampled=False)
dataset = tf.data.Dataset.from_generator(generator=generator, output_signature=output_signature)
if train:
if shuffle_buffer_size is not None:
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=None, reshuffle_each_iteration=True)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(prefetch)
return dataset
def _get_resampled_dataset(
self,
image_keys: np.ndarray,
nodes_idx: dict,
batch_size: int,
seed: Optional[int] = None,
prefetch: int = 100,
reinit_n_eval: Optional[int] = None,
):
"""Evaluate model based on resampled dataset for posterior resampling.
node_1 + domain_1 -> encoder -> z_1 + domain_2 -> decoder -> reconstruction_2.
Parameters
----------
image_keys: np.array
Image keys in partition.
nodes_idx : dict
Dictionary of nodes per image in partition.
batch_size : int
Batch size.
seed : int, optional
Seed.
prefetch : int
Prefetch.
reinit_n_eval : int, optional
Used if model is reinitialized to different number of nodes per graph.
Returns
-------
A tensorflow dataset.
"""
np.random.seed(seed)
if reinit_n_eval is not None:
print(
"ATTENTION: specifying reinit_n_eval will change class argument n_eval_nodes_per_graph "
"from %i to %i" % (self.n_eval_nodes_per_graph, reinit_n_eval)
)
self.n_eval_nodes_per_graph = reinit_n_eval
def generator():
for key in image_keys:
if nodes_idx[key].size == 0: # needed for images where no nodes are selected
continue
idx_nodes = np.arange(0, self.a[key].shape[0])
index_list = [
np.asarray(
nodes_idx[key][self.n_eval_nodes_per_graph * i : self.n_eval_nodes_per_graph * (i + 1)],
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
resampled_index_list = [
np.asarray(
np.random.choice(
a=nodes_idx[key],
size=self.n_eval_nodes_per_graph,
replace=True,
),
dtype=np.int32,
)
for i in range(len(nodes_idx[key]) // self.n_eval_nodes_per_graph)
]
for i, indices in enumerate(index_list):
re_indices = resampled_index_list[i]
h_1 = self.h_1[key][idx_nodes]
diff = self.max_nodes - h_1.shape[0]
zeros = np.zeros((diff, h_1.shape[1]))
h_1 = np.asarray(np.concatenate((h_1, zeros), axis=0), dtype="float32")
re_h_1 = h_1[re_indices]
h_1 = h_1[indices]
if self.log_transform:
h_1 = np.log(h_1 + 1.0)
re_h_1 = np.log(re_h_1 + 1.0)
node_covar = self.node_covar[key][idx_nodes]
diff = self.max_nodes - node_covar.shape[0]
zeros = np.zeros((diff, node_covar.shape[1]))
node_covar = np.asarray(np.concatenate([node_covar, zeros], axis=0), dtype="float32")
re_node_covar = node_covar[re_indices]
node_covar = node_covar[indices]
sf = np.expand_dims(self.size_factors[key][idx_nodes], axis=1)
diff = self.max_nodes - sf.shape[0]
zeros = np.zeros((diff, sf.shape[1]))
sf = np.asarray(np.concatenate([sf, zeros], axis=0), dtype="float32")
re_sf = sf[re_indices, :]
sf = sf[indices, :]
g = np.zeros((self.n_domains,), dtype="int32")
g[self.domains[key]] = 1
if self.vi_model:
kl_dummy = np.zeros((self.n_eval_nodes_per_graph,), dtype="float32")
yield (h_1, sf, node_covar, g), (h_1, kl_dummy), (re_h_1, re_sf, re_node_covar, g), (
re_h_1,
kl_dummy,
)
else:
yield (h_1, sf, node_covar, g), h_1, (re_h_1, re_sf, re_node_covar, g), re_h_1
output_signature = self._get_output_signature(resampled=True)
dataset = tf.data.Dataset.from_generator(generator=generator, output_signature=output_signature)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(prefetch)
return dataset