ott.solvers.nn.icnn.ICNN.init
- ICNN.init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)
Initializes a module method with variables and returns modified variables.
init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module's __call__ function by default), passing *args and **kwargs, and returns a dictionary of initialized variables.
Example:
>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> import jax
>>> class Foo(nn.Module):
...     @nn.compact
...     def __call__(self, x, train):
...         x = nn.Dense(16)(x)
...         x = nn.BatchNorm(use_running_average=not train)(x)
...         x = nn.relu(x)
...         return nn.Dense(1)(x)
...
>>> module = Foo()
>>> key = jax.random.PRNGKey(0)
>>> variables = module.init(key, jnp.empty((1, 7)), train=False)
If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you must pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init.
Example:
>>> class Foo(nn.Module):
...     @nn.compact
...     def __call__(self, x, train):
...         x = nn.Dense(16)(x)
...         x = nn.BatchNorm(use_running_average=not train)(x)
...         x = nn.relu(x)
...         # Add Gaussian noise
...         noise_key = self.make_rng('noise')
...         x = x + jax.random.normal(noise_key, x.shape)
...         return nn.Dense(1)(x)
...
>>> module = Foo()
>>> rngs = {'params': jax.random.PRNGKey(0), 'noise': jax.random.PRNGKey(1)}
>>> variables = module.init(rngs, jnp.empty((1, 7)), train=False)
Jitting init initializes a model lazily, using only the shapes of the provided arguments, and avoids computing the forward pass with actual values.
Example:
>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.PRNGKey(0), jnp.empty((1, 7)))
init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.
- Parameters
rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.
*args – Positional arguments passed to the init function.
method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method; if not provided, applies the __call__ method. A string can also be provided to specify a method by name.
mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable:
  bool: all/no collections are mutable.
  str: the name of a single mutable collection.
  list: a list of names of mutable collections.
  By default all collections except "intermediates" are mutable.
capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior: it takes the Module instance and the method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs – Keyword arguments passed to the init function.
- Return type
FrozenDict[str, Mapping[str, Any]]
- Returns
The initialized variable dict.