Source code for ncem.models.model_cvae_ncem

from typing import Union

import numpy as np
import tensorflow as tf

from ncem.models.layers import (CondDecoder, CondEncoder, GCNLayer, MaxLayer, get_out, PreprocInput, SamplingPrior)


class ModelCVAEncem:
    """Model class for NCEM conditional variational autoencoder with graph layer IND (MAX) or GCN.

    Constructing an instance builds several Keras models that share the same layers and
    stores them as attributes:

    - ``graph_condition_model``: maps the inputs to the neighbourhood embedding.
    - ``encoder``: maps the inputs to the outputs of the conditional encoder.
    - ``decoder_sampling``: decodes a latent sample generated inside the model.
    - ``decoder``: decodes a latent sample that is supplied as the ``z_sampling`` input.
    - ``training_model``: full encoder-decoder model, with an ``"elbo"`` metric attached.

    The sub-layers are available as ``encoder_model`` and ``decoder_model`` and the
    recorded hyper-parameters as ``args``.
    """

    def __init__(
        self,
        input_shapes,
        latent_dim: int = 10,
        intermediate_dim_enc: int = 128,
        intermediate_dim_dec: int = 128,
        depth_enc: int = 1,
        depth_dec: int = 1,
        dropout_rate: float = 0.1,
        l2_coef: float = 0.0,
        l1_coef: float = 0.0,
        cond_type: str = "gcn",
        cond_depth: int = 1,
        cond_dim: int = 8,
        cond_dropout_rate: float = 0.1,
        cond_activation: Union[str, tf.keras.layers.Layer] = "relu",
        cond_l2_reg: float = 0.0,
        cond_use_bias: bool = True,
        use_domain: bool = False,
        scale_node_size: bool = False,
        use_type_cond: bool = True,
        use_batch_norm: bool = False,
        transform_input: bool = False,
        output_layer: str = "gaussian",
        **kwargs,
    ):
        """Initialize conditional variational autoencoder NCEM model.

        Parameters
        ----------
        input_shapes
            Input shapes. Indexed positionally as ``[in_node_feature_dim, out_node_feature_dim,
            graph_dim, in_node_dim, categ_condition_dim, domain_dim]``.
        latent_dim : int
            Latent dimension.
        intermediate_dim_enc : int
            Encoder intermediate dimension.
        intermediate_dim_dec : int
            Decoder intermediate dimension.
        depth_enc : int
            Encoder depth.
        depth_dec : int
            Decoder depth.
        dropout_rate : float
            Dropout rate.
        l2_coef : float
            l2 regularization coefficient.
        l1_coef : float
            l1 regularization coefficient.
        cond_type : str
            Graph conditional type, either ``"gcn"`` or ``"max"``.
        cond_depth : int
            Graph conditional depth.
        cond_dim : int
            Graph conditional dimension.
        cond_dropout_rate : float
            Graph conditional dropout rate.
        cond_activation : str
            Graph conditional activation.
        cond_l2_reg : float
            Graph conditional l2 regularization coefficient.
        cond_use_bias : bool
            Graph conditional use bias.
        use_domain : bool
            Whether to use domain information.
        scale_node_size : bool
            Whether to scale output layer by node sizes.
        use_type_cond : bool
            Whether to use the categorical cell type label in conditional.
        use_batch_norm : bool
            Whether to use batch normalization.
        transform_input : bool
            Whether to transform input.
        output_layer : str
            Output layer. The metric attached to the training model supports ``"gaussian"``,
            ``"gaussian_const_disp"``, ``"nb"``, ``"nb_const_disp"`` and ``"nb_shared_disp"``.
        kwargs
            Arbitrary keyword arguments. They are accepted but not used by this constructor.

        Raises
        ------
        ValueError
            If `cond_type` or `output_layer` is not recognized.
        """
        super().__init__()
        # NOTE(review): `use_batch_norm` and `transform_input` are accepted above but are not
        # recorded in `self.args` - confirm whether that omission is intentional.
        self.args = {
            "input_shapes": input_shapes,
            "latent_dim": latent_dim,
            "intermediate_dim_enc": intermediate_dim_enc,
            "intermediate_dim_dec": intermediate_dim_dec,
            "depth_enc": depth_enc,
            "depth_dec": depth_dec,
            "dropout_rate": dropout_rate,
            "l2_coef": l2_coef,
            "l1_coef": l1_coef,
            "cond_type": cond_type,
            "cond_depth": cond_depth,
            "cond_dim": cond_dim,
            "cond_dropout_rate": cond_dropout_rate,
            # Layer instances are stored by name so that the dictionary only holds plain values.
            "cond_activation": cond_activation
            if not isinstance(cond_activation, tf.keras.layers.Layer)
            else cond_activation.name,
            "cond_l2_reg": cond_l2_reg,
            "cond_use_bias": cond_use_bias,
            "use_domain": use_domain,
            "use_type_cond": use_type_cond,
            "scale_node_size": scale_node_size,
            "output_layer": output_layer,
        }

        in_node_feature_dim = input_shapes[0]
        out_node_feature_dim = input_shapes[1]
        graph_dim = input_shapes[2]
        in_node_dim = input_shapes[3]
        categ_condition_dim = input_shapes[4]
        domain_dim = input_shapes[5]

        # Notation used in the original comments:
        # N - maximum number of nodes in graph (graph_dim)
        # M - number of nodes in graph evaluated in each batch (in_node_dim)
        # F - number of node features in reconstruction space
        # P - number of node covariate dimensions

        # node_features - reconstruction: Input Tensor - shape=(None, in_node_dim, out_node_feature_dim)
        input_x_reconstruct = tf.keras.Input(
            shape=(in_node_dim, out_node_feature_dim), name="node_features_reconstruct"
        )
        # node size - reconstruction: Input Tensor - shape=(None, in_node_dim, 1)
        input_node_size = tf.keras.Input(shape=(in_node_dim, 1), name="node_size_reconstruct")
        # node_features - conditions: Input Tensor - shape=(None, in_node_dim, in_node_feature_dim)
        # This input is part of every model signature below but is not consumed by any layer here.
        input_x_cond = tf.keras.Input(shape=(in_node_dim, in_node_feature_dim), name="node_features_cond")
        # node features - node representation of other nodes:
        # Input Tensor - shape=(None, graph_dim, in_node_feature_dim)
        input_x_cond_full = tf.keras.Input(shape=(graph_dim, in_node_feature_dim), name="node_features_cond_full")
        # adj_matrices - A: sparse Input Tensor - shape=(None, in_node_dim, graph_dim)
        input_a = tf.keras.Input(shape=(in_node_dim, graph_dim), name="adjacency_matrix", sparse=True)
        # full adj_matrices - A: sparse Input Tensor - shape=(None, graph_dim, graph_dim)
        input_afull = tf.keras.Input(shape=(graph_dim, graph_dim), name="adjacency_matrix_full", sparse=True)
        # Categorical predictors: Input Tensor - shape=(None, in_node_dim, categ_condition_dim)
        input_categ_condition = tf.keras.Input(shape=(in_node_dim, categ_condition_dim), name="categorical_predictor")
        # domain information of graph: int32 Input Tensor - shape=(None, domain_dim)
        input_g = tf.keras.layers.Input(shape=(domain_dim,), name="input_da_group", dtype="int32")

        # Decoder inputs:
        # 1) Latent sample generated inside the model by SamplingPrior:
        latent_sampling1 = SamplingPrior(width=latent_dim)(input_x_reconstruct)
        latent_sampling_reshaped1 = tf.reshape(latent_sampling1, [-1, latent_dim])
        # 2) Latent sample supplied as a data input:
        input_latent_sampling2 = tf.keras.Input(shape=(in_node_dim, latent_dim), name="z_sampling")
        latent_sampling_reshaped2 = tf.reshape(input_latent_sampling2, [-1, latent_dim])

        if use_domain:
            # Broadcast the per-graph domain vector to every evaluated node and append it
            # to the categorical predictors.
            categ_condition = tf.concat(
                [
                    input_categ_condition,
                    tf.tile(tf.expand_dims(tf.cast(input_g, dtype="float32"), axis=-2), [1, in_node_dim, 1]),
                ],
                axis=-1,
            )
        else:
            categ_condition = input_categ_condition

        if cond_depth > 1:
            print("using multi layer graph model")
        x_neighbour_embedding = input_x_cond_full
        if cond_type == "gcn":
            for i in range(cond_depth - 1):
                cond_layer = GCNLayer(
                    output_dim=cond_dim,
                    dropout_rate=cond_dropout_rate,
                    activation=cond_activation,
                    l2_reg=cond_l2_reg,
                    use_bias=cond_use_bias,
                    name=f"conditional_layer_stack_{i}",
                )
                x_neighbour_embedding = cond_layer([x_neighbour_embedding, input_afull])
            # In last forward pass, only retain nodes that are to be decoded: not using the full adjacency matrix:
            cond_layer = GCNLayer(
                output_dim=cond_dim,
                dropout_rate=cond_dropout_rate,
                activation=cond_activation,
                l2_reg=cond_l2_reg,
                use_bias=cond_use_bias,
                name=f"conditional_layer_stack_{cond_depth}",
            )
            x_neighbour_embedding = cond_layer([x_neighbour_embedding, input_a])
        elif cond_type == "max":
            for i in range(cond_depth - 1):
                cond_layer = MaxLayer(name=f"conditional_layer_stack_{i}")
                x_neighbour_embedding = cond_layer([x_neighbour_embedding, input_afull])
            # In last forward pass, only retain nodes that are to be decoded: not using the full adjacency matrix:
            cond_layer = MaxLayer(name=f"conditional_layer_stack_{cond_depth}")
            x_neighbour_embedding = cond_layer([x_neighbour_embedding, input_a])
        else:
            raise ValueError(f"tried to access a non-supported conditional layer {cond_type}")

        x = PreprocInput()(input_x_reconstruct) if transform_input else input_x_reconstruct
        self.encoder_model = CondEncoder(
            latent_dim=latent_dim,
            intermediate_dim=intermediate_dim_enc,
            dropout_rate=dropout_rate,
            n_hidden=depth_enc,
            l1_coef=l1_coef,
            l2_coef=l2_coef,
            use_type_cond=use_type_cond,
            use_batch_norm=use_batch_norm,
            probabilistic=True,
        )
        output_encoder = self.encoder_model((x, x_neighbour_embedding, categ_condition))
        z, z_mean, z_log_var = output_encoder
        latent_space = tf.keras.layers.Concatenate(axis=1, name="bottleneck")([z, z_mean, z_log_var])
        # The two sampling decoders emit all-zero placeholders in place of the bottleneck output.
        latent_space_sampling = tf.zeros_like(latent_space, name="bottleneck")
        latent_space2 = tf.keras.layers.Concatenate(axis=1, name="bottleneck")(
            [  # imitate latent_space tensor
                tf.zeros_like(input_latent_sampling2),
                tf.zeros_like(input_latent_sampling2),
                tf.zeros_like(input_latent_sampling2),
            ]
        )

        # A single decoder instance is applied to all three latent sources, so its weights are shared.
        self.decoder_model = CondDecoder(
            intermediate_dim=intermediate_dim_dec,
            dropout_rate=dropout_rate,
            n_hidden=depth_dec,
            l1_coef=l1_coef,
            l2_coef=l2_coef,
            use_type_cond=use_type_cond,
            use_batch_norm=use_batch_norm,
        )
        output_decoder = self.decoder_model((z, x_neighbour_embedding, categ_condition))
        sampling_decoder1 = self.decoder_model((latent_sampling_reshaped1, x_neighbour_embedding, categ_condition))
        sampling_decoder2 = self.decoder_model((latent_sampling_reshaped2, x_neighbour_embedding, categ_condition))

        output_decoder_layer = get_out(
            output_layer=output_layer, out_feature_dim=out_node_feature_dim, scale_node_size=scale_node_size
        )((output_decoder, input_node_size))
        output_sampling_decoder1 = get_out(
            output_layer=output_layer,
            out_feature_dim=out_node_feature_dim,
            scale_node_size=scale_node_size,
            name='sampling1',
        )((sampling_decoder1, input_node_size))
        output_sampling_decoder2 = get_out(
            output_layer=output_layer,
            out_feature_dim=out_node_feature_dim,
            scale_node_size=scale_node_size,
            name='sampling2',
        )((sampling_decoder2, input_node_size))

        output_decoder_concat = tf.keras.layers.Concatenate(axis=2, name="reconstruction")(output_decoder_layer)
        output_sampling_concat1 = tf.keras.layers.Concatenate(axis=2, name="reconstruction")(output_sampling_decoder1)
        output_sampling_concat2 = tf.keras.layers.Concatenate(axis=2, name="reconstruction")(output_sampling_decoder2)

        self.graph_condition_model = tf.keras.Model(
            inputs=[
                input_x_reconstruct,
                input_x_cond,
                input_x_cond_full,
                input_a,
                input_afull,
                input_categ_condition,
                input_g,
            ],
            outputs=x_neighbour_embedding,
            name="graph_conditional",
        )
        self.encoder = tf.keras.Model(
            inputs=[
                input_x_reconstruct,
                input_x_cond,
                input_x_cond_full,
                input_a,
                input_afull,
                input_categ_condition,
                input_g,
            ],
            outputs=output_encoder,
            name="encoder_ncem",
        )
        self.decoder_sampling = tf.keras.Model(
            inputs=[
                input_x_reconstruct,
                input_node_size,
                input_x_cond,
                input_x_cond_full,
                input_a,
                input_afull,
                input_categ_condition,
                input_g,
            ],
            outputs=[output_sampling_concat1, latent_space_sampling],
            name="decoder_sampling_ncem",
        )
        self.decoder = tf.keras.Model(
            inputs=[
                input_latent_sampling2,
                input_node_size,
                input_x_cond,
                input_x_cond_full,
                input_a,
                input_afull,
                input_categ_condition,
                input_g,
            ],
            outputs=[output_sampling_concat2, latent_space2],
            name="decoder_ncem",
        )
        self.training_model = tf.keras.Model(
            inputs=[
                input_x_reconstruct,
                input_node_size,
                input_x_cond,
                input_x_cond_full,
                input_a,
                input_afull,
                input_categ_condition,
                input_g,
            ],
            outputs=[output_decoder_concat, latent_space],
            name="cvae_ncem",
        )

        # Add non-scaled ELBO to model as metric (ie no annealing or beta-VAE scaling).
        # The value that is registered is `neg_ll + d_kl`, i.e. the quantity that is minimised.
        log2pi = tf.math.log(2.0 * np.pi)
        logqz_x = -0.5 * tf.reduce_mean(tf.square(z - z_mean) * tf.exp(-z_log_var) + z_log_var + log2pi)
        logpz = -0.5 * tf.reduce_mean(tf.square(z) + log2pi)
        d_kl = logqz_x - logpz
        loc, scale = output_decoder_layer
        if output_layer in ["gaussian", "gaussian_const_disp"]:
            # `np.pi` instead of `np.math.pi`: the `np.math` alias no longer exists in NumPy 2.
            neg_ll = tf.math.log(tf.sqrt(2 * np.pi) * scale) + 0.5 * tf.math.square(
                loc - input_x_reconstruct
            ) / tf.math.square(scale)
        elif output_layer in ["nb", "nb_const_disp", "nb_shared_disp"]:
            # Negative binomial log-likelihood with `loc` as mean and `scale` as dispersion term.
            eta_loc = tf.math.log(loc)
            eta_scale = tf.math.log(scale)
            log_r_plus_mu = tf.math.log(scale + loc)
            ll = tf.math.lgamma(scale + input_x_reconstruct)
            ll = ll - tf.math.lgamma(input_x_reconstruct + tf.ones_like(input_x_reconstruct))
            ll = ll - tf.math.lgamma(scale)
            ll = (
                ll
                + tf.multiply(input_x_reconstruct, eta_loc - log_r_plus_mu)
                + tf.multiply(scale, eta_scale - log_r_plus_mu)
            )
            # Clip to keep the metric finite for extreme log-probabilities.
            neg_ll = -tf.clip_by_value(ll, -300, 300, "log_probs")
        else:
            # Fail with a clear message instead of passing `None` on to `tf.reduce_sum`.
            raise ValueError(f"tried to access a non-supported output layer {output_layer}")
        neg_ll = tf.reduce_mean(tf.reduce_sum(neg_ll, axis=-1))
        self.training_model.add_metric(neg_ll + d_kl, name="elbo", aggregation="mean")