Source code for ott.solvers.nn.icnn

# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of :cite:`amos:17` input convex neural networks (ICNN)."""

from typing import Any, Callable, Optional, Sequence, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax.nn import initializers

from ott.math import matrix_square_root
from ott.solvers.nn.layers import PosDefPotentials, PositiveDense

__all__ = ["ICNN"]

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any


class ICNN(nn.Module):
  """Input convex neural network (ICNN) architecture with initialization.

  Implementation of input convex neural networks as introduced in
  :cite:`amos:17`, with initialization schemes proposed by :cite:`bunne:22`.

  Args:
    dim_data: data dimensionality.
    dim_hidden: sequence specifying size of hidden dimensions. The output
      dimension of the last layer is 1 by default.
    init_std: value of standard deviation of weight initialization method.
    init_fn: choice of initialization method for weight matrices
      (default: `jax.nn.initializers.normal`).
    act_fn: choice of activation function used in network architecture
      (needs to be convex, default: `nn.relu`).
    pos_weights: whether to enforce positivity of the weights directly or
      to rely on a regularizer instead.
    gaussian_map: data inputs of source and target measures for the
      initialization scheme based on a Gaussian approximation of input and
      target measure (if `None`, identity initialization is used).
  """
  dim_data: int
  dim_hidden: Sequence[int]
  init_std: float = 1e-1
  init_fn: Callable = jax.nn.initializers.normal
  act_fn: Callable = nn.relu
  pos_weights: bool = True
  gaussian_map: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None
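
  # A minimal usage sketch (sizes and shapes below are illustrative, not
  # mandated by this module): the network maps a batch of points in
  # R^dim_data to one scalar potential value per point.
  #
  #   icnn = ICNN(dim_data=2, dim_hidden=[64, 64, 32])
  #   params = icnn.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))["params"]
  #   potentials = icnn.apply({"params": params}, jnp.ones((10, 2)))  # (10,)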
  def setup(self) -> None:
    self.num_hidden = len(self.dim_hidden)

    if self.pos_weights:
      hid_dense = PositiveDense
      # this function needs to be the inverse map of the rectifier
      # function used in the PositiveDense layers
      rescale = hid_dense.inv_rectifier_fn
    else:
      hid_dense = nn.Dense
      rescale = lambda x: x
    self.use_init = False
    # check if a Gaussian map was provided
    if self.gaussian_map is not None:
      factor, mean = self._compute_gaussian_map(self.gaussian_map)
    else:
      factor, mean = self._compute_identity_map(self.dim_data)

    w_zs = []
    # keep track of the previous layer size to normalize accordingly
    normalization = 1
    # subsequent layers propagate the value of the potential provided by
    # the first layer in x; the normalization factor is rescaled accordingly
    for i in range(self.num_hidden):
      w_zs.append(
          hid_dense(
              self.dim_hidden[i],
              kernel_init=initializers.constant(rescale(1.0 / normalization)),
              use_bias=False,
          )
      )
      normalization = self.dim_hidden[i]
    # final layer computes an average, still with normalized rescaling
    w_zs.append(
        hid_dense(
            1,
            kernel_init=initializers.constant(rescale(1.0 / normalization)),
            use_bias=False,
        )
    )
    self.w_zs = w_zs

    w_xs = []
    # first quadratic layer, initialized to the identity (or Gaussian) map
    w_xs.append(
        PosDefPotentials(
            self.dim_data,
            num_potentials=1,
            kernel_init=lambda *args, **kwargs: factor,
            bias_init=lambda *args, **kwargs: mean,
            use_bias=True,
        )
    )
    # subsequent layers re-inject x into the convex functions
    for i in range(self.num_hidden):
      w_xs.append(
          nn.Dense(
              self.dim_hidden[i],
              kernel_init=self.init_fn(self.init_std),
              bias_init=self.init_fn(self.init_std),
              use_bias=True,
          )
      )
    # final layer, to output a scalar
    w_xs.append(
        nn.Dense(
            1,
            kernel_init=self.init_fn(self.init_std),
            bias_init=self.init_fn(self.init_std),
            use_bias=True,
        )
    )
    self.w_xs = w_xs
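
  # Note on `rescale` above: the constant initialization should yield an
  # effective weight of 1 / normalization *after* the rectifier inside
  # PositiveDense is applied, so `rescale` must be that rectifier's exact
  # inverse. A sketch of the identity relied upon, assuming the rectifier
  # is exposed as `PositiveDense.rectifier_fn` (see ott.solvers.nn.layers):
  #
  #   c = 0.25
  #   raw = PositiveDense.inv_rectifier_fn(c)
  #   assert jnp.allclose(PositiveDense.rectifier_fn(raw), c)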
  @staticmethod
  def _compute_gaussian_map(
      inputs: Tuple[jnp.ndarray, jnp.ndarray]
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:

    def compute_moments(x, reg: float = 1e-4, sqrt_inv: bool = False):
      shape = x.shape
      z = x.reshape(shape[0], -1)
      mu = jnp.expand_dims(jnp.mean(z, axis=0), 0)
      z = z - mu
      matmul = lambda a, b: jnp.matmul(a, b)
      sigma = jax.vmap(matmul)(jnp.expand_dims(z, 2), jnp.expand_dims(z, 1))
      # unbiased estimate
      sigma = jnp.sum(sigma, axis=0) / (shape[0] - 1)
      # regularize
      sigma = sigma + reg * jnp.eye(shape[1])

      if sqrt_inv:
        sigma_sqrt, sigma_inv_sqrt, _ = matrix_square_root.sqrtm(sigma)
        return sigma, sigma_sqrt, sigma_inv_sqrt, mu
      return sigma, mu

    source, target = inputs
    _, covs_sqrt, covs_inv_sqrt, mus = compute_moments(source, sqrt_inv=True)
    covt, mut = compute_moments(target, sqrt_inv=False)

    # linear part of the Monge map between the two Gaussian approximations
    mo = matrix_square_root.sqrtm_only(
        jnp.dot(jnp.dot(covs_sqrt, covt), covs_sqrt)
    )
    A = jnp.dot(jnp.dot(covs_inv_sqrt, mo), covs_inv_sqrt)
    b = jnp.squeeze(mus) - jnp.linalg.solve(A, jnp.squeeze(mut))
    # return a square root of A, as expected by PosDefPotentials
    A = matrix_square_root.sqrtm_only(A)

    return jnp.expand_dims(A, 0), jnp.expand_dims(b, 0)

  @staticmethod
  def _compute_identity_map(input_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    A = jnp.eye(input_dim).reshape((1, input_dim, input_dim))
    b = jnp.zeros((1, input_dim))
    return A, b

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> float:
    for i in range(self.num_hidden + 2):
      if i == 0:
        z = self.w_xs[i](x)
      else:
        # apply both transforms, on the hidden state and on x; x is one
        # step ahead, as there is one more hidden layer for x
        z = jnp.add(self.w_zs[i - 1](z), self.w_xs[i](x))
      # activation on all layers except the final (linear) output layer
      if i != self.num_hidden + 1:
        z = self.act_fn(z)
    return jnp.squeeze(z)
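
  # The pair (factor, mean) produced by `_compute_gaussian_map` encodes the
  # affine Monge map between the Gaussian approximations N(mu_s, cov_s) and
  # N(mu_t, cov_t): with
  #   A = cov_s^{-1/2} (cov_s^{1/2} cov_t cov_s^{1/2})^{1/2} cov_s^{-1/2}
  # and b = mu_s - A^{-1} mu_t, the quadratic potential
  # 0.5 * (x - b)^T A (x - b) has gradient A (x - mu_s) + mu_t. A quick
  # sanity check on synthetic data (shapes are illustrative):
  #
  #   k1, k2 = jax.random.split(jax.random.PRNGKey(0))
  #   source = jax.random.normal(k1, (1024, 2))
  #   target = 2.0 * jax.random.normal(k2, (1024, 2)) + 5.0
  #   factor, mean = ICNN._compute_gaussian_map((source, target))
  #   # factor: (1, 2, 2), a square root of A; mean: (1, 2), the vector b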
  def create_train_state(
      self,
      rng: jnp.ndarray,
      optimizer: optax.GradientTransformation,
      input: Union[int, Tuple[int, ...]],
  ) -> train_state.TrainState:
    """Create an initial `TrainState`."""
    params = self.init(rng, jnp.ones(input))["params"]
    return train_state.TrainState.create(
        apply_fn=self.apply, params=params, tx=optimizer
    )
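
# A minimal end-to-end sketch for `create_train_state` (the optimizer and
# shapes are illustrative; `x_batch` is a placeholder array of shape (n, 2)):
#
#   icnn = ICNN(dim_data=2, dim_hidden=[64, 64])
#   state = icnn.create_train_state(
#       jax.random.PRNGKey(0), optax.adam(1e-3), (1, 2)
#   )
#   values = state.apply_fn({"params": state.params}, x_batch)
#   # gradients of the potential w.r.t. inputs give the candidate transport map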