leniax.core package
Leniax core simulation functions
- leniax.core.update(rng_key, state, K, gf_params, kernels_weight_per_channel, dt, get_potential_fn, get_field_fn, get_state_fn)
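Update the cell state.
Jitted function with static argnums. Use functools.partial to bind the function arguments (get_potential_fn, get_field_fn, get_state_fn). Avoid changing the shape of non-static arguments between calls, since each new shape triggers a recompilation.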
- Parameters
rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.
state (jax._src.numpy.ndarray.ndarray) – Cell state, shape [N, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Kernel, shape [K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters, shape [nb_kernels, params_shape...]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the averaging function, shape [nb_channels, nb_kernels]
dt (jax._src.numpy.ndarray.ndarray) – Update rate, shape [N]
get_potential_fn (Callable) – (jit static arg) Function used to compute the potential
get_field_fn (Callable) – (jit static arg) Function used to compute the field
get_state_fn (Callable[[jax._src.prng.PRNGKeyArray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]) – (jit static arg) Function used to compute the new cell state
- Returns
A tuple of arrays representing the JAX PRNG key and the updated state, as well as the potential and field used for the update.
- Return type
Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray]
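The binding pattern described above can be sketched as follows. This is not leniax's code: `toy_update`, `toy_potential`, `toy_field` and `toy_state` are hypothetical stand-ins, and the order potential, then field, then state is inferred from the parameter descriptions. The sketch returns a 3-tuple to match the documented return type and lets `get_state_fn` receive (but ignore) the PRNG key.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical toy stand-ins with the documented call signatures (not leniax's code).
def toy_potential(state, K):
    # Scale every cell by the kernel mass; a real potential would convolve.
    return state * K.sum()


def toy_field(potential, gf_params, kernels_weight_per_channel):
    # Gaussian growth with (mu, sigma) taken from gf_params[0]; weights unused here.
    mu, sigma = gf_params[0, 0], gf_params[0, 1]
    return 2.0 * jnp.exp(-((potential - mu) ** 2) / (2.0 * sigma**2)) - 1.0


def toy_state(rng_key, state, field, dt):
    # Deterministic Euler step clipped to [0, 1]; rng_key is accepted but unused.
    dt = dt.reshape((-1,) + (1,) * (state.ndim - 1))  # [N] -> [N, 1, 1, 1]
    return jnp.clip(state + dt * field, 0.0, 1.0)


# The three callables are static: jit traces and compiles once per distinct triple.
@functools.partial(jax.jit, static_argnums=(6, 7, 8))
def toy_update(rng_key, state, K, gf_params, kernels_weight_per_channel, dt,
               get_potential_fn, get_field_fn, get_state_fn):
    potential = get_potential_fn(state, K)
    field = get_field_fn(potential, gf_params, kernels_weight_per_channel)
    new_state = get_state_fn(rng_key, state, field, dt)
    return new_state, potential, field


# Bind the static callables once, then call the result with arrays only.
step = functools.partial(
    toy_update,
    get_potential_fn=toy_potential,
    get_field_fn=toy_field,
    get_state_fn=toy_state,
)

key = jax.random.PRNGKey(0)
state = jnp.full((2, 1, 8, 8), 0.5)      # [N, nb_channels, H, W]
K = jnp.ones((1, 1, 3, 3)) / 9.0         # kernel mass of 1
gf_params = jnp.array([[0.15, 0.015]])   # [nb_kernels, (mu, sigma)]
weights = jnp.ones((1, 1))               # [nb_channels, nb_kernels]
dt = jnp.array([0.1, 0.1])               # [N]
new_state, potential, field = step(key, state, K, gf_params, weights, dt)
```

Because the functions are bound once, repeated calls with same-shaped arrays reuse the same compiled program.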
- leniax.core.get_potential_fft(state, K, tc_indices=None, channel_first=True)
Compute the potential using FFT.
The first dimension of state and K is the vmap dimension.
- Parameters
state (jax._src.numpy.ndarray.ndarray) – Cell state, shape [N, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Kernels, shape [1, nb_channels, max_k_per_channel, world_dims...]
tc_indices (Optional[Tuple]) – Optional 1-dim array of channel indices to keep
channel_first (bool) – see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html
- Returns
An array containing the potential
- Return type
jax._src.numpy.ndarray.ndarray
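The FFT route relies on the convolution theorem: a pointwise product in Fourier space is a circular (wrap-around) convolution in world space. The sketch below is an illustration under stated assumptions: a 2-D, channel-first world; `fft_potential` is a hypothetical helper; its kernels are given in the spatial domain on the full world grid; and `tc_indices` is not handled. How leniax stores or pre-transforms `K` is not stated on this page.

```python
import jax.numpy as jnp
import numpy as np


def fft_potential(state, K_world):
    """Hypothetical FFT potential for a 2-D world.

    state:   [N, C, H, W]
    K_world: [1, C, max_k, H, W], spatial-domain kernels on the world grid
    returns: [N, C * max_k, H, W]
    """
    N, C, H, W = state.shape
    state_hat = jnp.fft.rfft2(state)[:, :, None]       # [N, C, 1, H, W//2 + 1]
    K_hat = jnp.fft.rfft2(K_world)                     # [1, C, max_k, H, W//2 + 1]
    # Pointwise product == circular convolution in world space
    pot = jnp.fft.irfft2(state_hat * K_hat, s=(H, W))  # [N, C, max_k, H, W]
    return pot.reshape(N, -1, H, W)


# Check against a direct circular convolution built from np.roll
rng = np.random.default_rng(0)
x = rng.random((1, 1, 6, 6)).astype(np.float32)
k = np.zeros((1, 1, 1, 6, 6), np.float32)
k[0, 0, 0, :3, :3] = rng.random((3, 3))
out = np.asarray(fft_potential(jnp.asarray(x), jnp.asarray(k)))
ref = sum(k[0, 0, 0, a, b] * np.roll(x[0, 0], (a, b), axis=(0, 1))
          for a in range(3) for b in range(3))
```

The wrap-around boundary comes for free with this approach, which is why no padding argument is needed here.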
- leniax.core.get_potential(state, K, padding, tc_indices=None, channel_first=True)
Compute the potential using lax.conv_general_dilated.
The first dimension of state and K is the vmap dimension.
- Parameters
state (jax._src.numpy.ndarray.ndarray) – Cell state, shape [N, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Kernels, shape [K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]
padding (Tuple) – Padding information, shape [nb_world_dims, 2]
tc_indices (Optional[Tuple]) – Optional 1-dim array of channel indices to keep
channel_first (bool) – see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html
- Returns
An array containing the potential
- Return type
jax._src.numpy.ndarray.ndarray
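One way to map the documented kernel layout `[nb_channels * max_k_per_channel, 1, kernel_dims...]` onto `lax.conv_general_dilated` is a grouped convolution with `feature_group_count=nb_channels`, so each channel is only convolved with its own kernels. The sketch assumes a 2-D world, channel-first layout, and wrap-around padding applied with `jnp.pad` before a `VALID` convolution; `conv_potential` is a hypothetical name, and whether leniax pads this way is not stated here. Note that `lax.conv_general_dilated` computes a cross-correlation (the kernel is not flipped).

```python
import jax.numpy as jnp
from jax import lax


def conv_potential(state, K, padding):
    """Hypothetical potential via a grouped (depthwise) convolution.

    state:   [N, C, H, W]
    K:       [C * max_k, 1, kh, kw]  (OIHW layout, one input channel per group)
    padding: [[top, bottom], [left, right]] for the two world dims
    returns: [N, C * max_k, H', W']
    """
    C = state.shape[1]
    # Wrap-around padding makes the world a torus; the convolution itself is VALID
    pad_width = ((0, 0), (0, 0)) + tuple(tuple(p) for p in padding)
    padded = jnp.pad(state, pad_width, mode="wrap")
    return lax.conv_general_dilated(
        padded,
        K,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=C,  # outputs c*max_k .. c*max_k+max_k-1 read channel c
    )


x = jnp.arange(2 * 5 * 5, dtype=jnp.float32).reshape(1, 2, 5, 5)  # N=1, C=2
K = jnp.ones((4, 1, 3, 3), jnp.float32)                           # max_k=2 -> K_o=4
out = conv_potential(x, K, [[1, 1], [1, 1]])                      # [1, 4, 5, 5]
```

With a padding of `(k - 1) / 2` per side, the output keeps the world size, and each output channel is a neighbourhood sum over its source channel.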
- leniax.core.get_field(potential, gf_params, kernels_weight_per_channel, growth_fn_t, weighted_fn)
Compute the field.
Jitted function with static argnums. Use functools.partial to bind the function arguments (growth_fn_t, weighted_fn). Avoid changing the shape of non-static arguments between calls, since each new shape triggers a recompilation.
- Parameters
potential (jax._src.numpy.ndarray.ndarray) – Potential, shape [N, nb_kernels, world_dims...]
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters, shape [nb_kernels, nb_gf_params]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the averaging function, shape [nb_channels, nb_kernels]
growth_fn_t (Tuple[Callable, ...]) – (jit static arg) Tuple of growth functions, length nb_kernels
weighted_fn (Callable) – (jit static arg) Function used to merge fields linked to the same channel
- Returns
An array containing the field
- Return type
jax._src.numpy.ndarray.ndarray
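The two stages described for get_field (one growth function per kernel, then a channel-wise merge by `weighted_fn`) can be sketched as follows. Everything here is hypothetical: `sketch_get_field`, `gaussian_growth`, `einsum_weighted_sum`, and the assumption that each row of `gf_params` holds `(mu, sigma)` are illustrative choices, not leniax's definitions.

```python
import functools

import jax
import jax.numpy as jnp


def gaussian_growth(x, mu, sigma):
    # Hypothetical Lenia-style growth mapping a potential to [-1, 1]
    return 2.0 * jnp.exp(-((x - mu) ** 2) / (2.0 * sigma**2)) - 1.0


def einsum_weighted_sum(fields, weights):
    # fields: [N, nb_kernels, ...], weights: [nb_channels, nb_kernels]
    return jnp.einsum("nk...,ck->nc...", fields, weights)


@functools.partial(jax.jit, static_argnums=(3, 4))
def sketch_get_field(potential, gf_params, kernels_weight_per_channel,
                     growth_fn_t, weighted_fn):
    # Apply the k-th growth function to the k-th potential slice
    fields = [
        growth_fn_t[k](potential[:, k], gf_params[k, 0], gf_params[k, 1])
        for k in range(len(growth_fn_t))
    ]
    fields = jnp.stack(fields, axis=1)                      # [N, nb_kernels, ...]
    return weighted_fn(fields, kernels_weight_per_channel)  # [N, nb_channels, ...]


potential = jnp.stack([jnp.full((1, 4, 4), 0.15), jnp.full((1, 4, 4), 1.0)], axis=1)
gf_params = jnp.array([[0.15, 0.015], [0.15, 0.015]])
weights = jnp.array([[1.0, 0.0],    # channel 0 uses kernel 0 only
                     [0.5, 0.5]])   # channel 1 mixes both kernels
field = sketch_get_field(potential, gf_params, weights,
                         (gaussian_growth, gaussian_growth), einsum_weighted_sum)
```

Passing the growth functions as a tuple keeps them hashable, which jit requires for static arguments.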
- leniax.core.weighted_sum(fields, weights)
Compute the weighted sum of sub fields
- Parameters
fields (jax._src.numpy.ndarray.ndarray) – Raw sub fields, shape [N, nb_kernels, world_dims...]
weights (jax._src.numpy.ndarray.ndarray) – Weights used to compute the sum, shape [nb_channels, nb_kernels]. A value of 0. indicates that a given channel does not receive input from that kernel.
- Returns
The unnormalized field
- Return type
jax._src.numpy.ndarray.ndarray
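Given the shapes above, a natural reading of the weighted sum is a contraction over the kernel axis, out[n, c, ...] = sum over k of weights[c, k] * fields[n, k, ...]. A minimal sketch with `jnp.einsum`, assuming this is the intended reduction; `weighted_sum_sketch` is a hypothetical name and leniax may implement it differently.

```python
import jax.numpy as jnp


def weighted_sum_sketch(fields, weights):
    """fields: [N, nb_kernels, world_dims...]; weights: [nb_channels, nb_kernels]."""
    # out[n, c, ...] = sum_k weights[c, k] * fields[n, k, ...]
    return jnp.einsum("nk...,ck->nc...", fields, weights)


fields = jnp.stack([jnp.ones((2, 3)), 2.0 * jnp.ones((2, 3))])[None]  # [1, 2, 2, 3]
weights = jnp.array([[1.0, 0.0],   # channel 0 listens only to kernel 0
                     [0.5, 2.0]])  # channel 1 mixes both kernels
out = weighted_sum_sketch(fields, weights)  # [1, 2, 2, 3]
```

The `...` in the einsum subscripts lets the same function serve 1-D, 2-D or 3-D worlds.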
- leniax.core.weighted_mean(fields, weights)
Compute the weighted mean of sub fields
- Parameters
fields (jax._src.numpy.ndarray.ndarray) – Raw sub fields, shape [N, nb_kernels, world_dims...]
weights (jax._src.numpy.ndarray.ndarray) – Weights used to compute the mean, shape [nb_channels, nb_kernels]. A value of 0. indicates that a given channel does not receive input from that kernel.
- Returns
The normalized field
- Return type
jax._src.numpy.ndarray.ndarray
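A weighted mean divides each channel's weighted sum by that channel's total weight. The sketch below assumes that normalization; the guard that leaves channels with all-zero weights at 0 instead of producing NaN is an added assumption, and `weighted_mean_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def weighted_mean_sketch(fields, weights):
    """fields: [N, nb_kernels, world_dims...]; weights: [nb_channels, nb_kernels]."""
    totals = weights.sum(axis=1, keepdims=True)               # [nb_channels, 1]
    # Assumed guard: a channel with no inputs keeps total 1 so its output stays 0
    safe_totals = jnp.where(totals == 0, 1.0, totals)
    normalized = weights / safe_totals                        # rows sum to 1 (or 0)
    return jnp.einsum("nk...,ck->nc...", fields, normalized)


fields = jnp.stack([jnp.ones((2, 3)), 2.0 * jnp.ones((2, 3))])[None]  # [1, 2, 2, 3]
weights = jnp.array([[1.0, 0.0],
                     [0.5, 2.0],
                     [0.0, 0.0]])   # channel 2 receives no input
out = weighted_mean_sketch(fields, weights)  # [1, 3, 2, 3]
```

Normalizing the weights before the contraction keeps the output on the same scale as the sub fields, which is what distinguishes this from the unnormalized weighted sum.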