leniax.core package

Leniax core simulation functions

leniax.core.update(rng_key, state, K, gf_params, kernels_weight_per_channel, dt, get_potential_fn, get_field_fn, get_state_fn)

Update the cells state

Jitted function with static argnums. Use functools.partial to bind the different functions. Avoid changing the shapes of non-static arguments between calls: each new shape triggers a recompilation.

Parameters
  • rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.

  • state (jax._src.numpy.ndarray.ndarray) – cells state [N, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Kernel [K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]

  • gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [nb_kernels, params_shape...]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the averaging function [nb_channels, nb_kernels]

  • dt (jax._src.numpy.ndarray.ndarray) – Update rate [N]

  • get_potential_fn (Callable) – (jit static arg) Function used to compute the potential

  • get_field_fn (Callable) – (jit static arg) Function used to compute the field

  • get_state_fn (Callable[[jax._src.prng.PRNGKeyArray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]) – (jit static arg) Function used to compute the new cell state

Returns

A tuple of arrays: the new JAX PRNG key, the updated state, the potential used and the field used.

Return type

Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray]
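The binding pattern recommended above (static callables fixed once with functools.partial, arrays passed on every call) can be sketched with a reduced, hypothetical update step. update_sketch, toy_field_fn and toy_state_fn are illustrative names, not part of leniax; the sketch drops K, gf_params, kernels_weight_per_channel and get_potential_fn to keep the focus on the jit/partial mechanics, and it assumes a recent JAX that supports static_argnames.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("get_field_fn", "get_state_fn"))
def update_sketch(rng_key, state, dt, get_field_fn, get_state_fn):
    # Hypothetical reduced update: compute a field, then derive the new state from it
    field = get_field_fn(state)
    rng_key, subkey = jax.random.split(rng_key)
    new_state = get_state_fn(subkey, state, field, dt)
    return rng_key, new_state, field


def toy_field_fn(state):
    # Hypothetical field: pulls every cell value towards 0.5
    return 0.5 - state


def toy_state_fn(key, state, field, dt):
    # Euler step with one rate per world (dt has shape [N]), clipped to [0, 1]
    dt = dt.reshape((-1,) + (1,) * (state.ndim - 1))
    return jnp.clip(state + dt * field, 0.0, 1.0)


# Bind the static callables once; subsequent calls only pass arrays,
# so the compiled function is reused as long as the array shapes stay fixed.
step = functools.partial(
    update_sketch, get_field_fn=toy_field_fn, get_state_fn=toy_state_fn
)

key = jax.random.PRNGKey(0)
state = jax.random.uniform(key, (2, 1, 8, 8))  # [N, nb_channels, H, W]
dt = jnp.array([0.1, 0.5])                      # [N]
key, state, field = step(key, state, dt)
```

Passing the callables as static arguments makes JAX treat them as part of the compilation cache key, so changing one of them recompiles while changing only array values does not.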

leniax.core.get_potential_fft(state, K, tc_indices=None, channel_first=True)

Compute the potential using FFT

The first dimension of state and K is the vmap dimension.

Parameters
  • state (jax._src.numpy.ndarray.ndarray) – cells state [N, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Kernels [1, nb_channels, max_k_per_channel, world_dims...]

  • tc_indices (Optional[Tuple]) – Optional 1-dim array of channel indices to keep

  • channel_first (bool) – see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html

Returns

An array containing the potential

Return type

jax._src.numpy.ndarray.ndarray
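The FFT approach rests on the convolution theorem: a circular convolution over the world becomes an element-wise product in Fourier space. The sketch below is a hypothetical illustration (potential_fft_sketch is not a leniax function) for 2D worlds, using the K layout documented above. It assumes K holds real-space kernels centred in the world grid and transforms them itself; the docstring does not say whether leniax stores K already transformed. tc_indices and channel_first are not modelled.

```python
import jax.numpy as jnp


def potential_fft_sketch(state, K):
    # state: [N, C, H, W]; K: [1, C, M, H, W] real-space kernels centred in the world
    N, C, H, W = state.shape
    M = K.shape[2]
    # Move each kernel centre to index (0, 0) so the product is a centred circular convolution
    K_fft = jnp.fft.fft2(jnp.fft.ifftshift(K, axes=(-2, -1)))  # [1, C, M, H, W]
    state_fft = jnp.fft.fft2(state)[:, :, None]                 # [N, C, 1, H, W]
    # Broadcasting applies every kernel of channel c to state channel c
    pot = jnp.real(jnp.fft.ifft2(state_fft * K_fft))            # [N, C, M, H, W]
    # Flatten (channel, kernel) into one axis: index c * M + m
    return pot.reshape(N, C * M, H, W)
```

The FFT implies periodic (toroidal) boundaries, and its cost does not grow with kernel radius, which is why it tends to pay off for large kernels.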

leniax.core.get_potential(state, K, padding, tc_indices=None, channel_first=True)

Compute the potential using lax.conv_general_dilated

The first dimension of state and K is the vmap dimension.

Parameters
  • state (jax._src.numpy.ndarray.ndarray) – cells state [N, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Kernels [K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]

  • padding (Tuple) – Array with padding information, [nb_world_dims, 2]

  • tc_indices (Optional[Tuple]) – Optional 1-dim array of channel indices to keep

  • channel_first (bool) – see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html

Returns

An array containing the potential

Return type

jax._src.numpy.ndarray.ndarray
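The kernel layout [K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...] matches a grouped (depthwise) convolution with feature_group_count=nb_channels. The sketch below is a hypothetical illustration (potential_conv_sketch is not a leniax function) for 2D worlds; the choice to apply padding as periodic wrap padding followed by a "VALID" convolution is an assumption, not confirmed by the docstring, and tc_indices is not modelled.

```python
import jax.numpy as jnp
from jax import lax


def potential_conv_sketch(state, K, padding):
    # state: [N, C, H, W]; K: [C * M, 1, kh, kw]; padding: ((top, bottom), (left, right))
    C = state.shape[1]
    # Assumed periodic world: wrap-pad the spatial dims only
    pad_width = ((0, 0), (0, 0)) + tuple(tuple(p) for p in padding)
    padded = jnp.pad(state, pad_width, mode="wrap")
    # With feature_group_count=C, output channel c * M + m only reads input channel c
    return lax.conv_general_dilated(
        padded,
        K,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=C,
    )
```

Note that lax.conv_general_dilated computes a cross-correlation (the kernel is not flipped); for the radially symmetric kernels typical of Lenia the two coincide.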

leniax.core.get_field(potential, gf_params, kernels_weight_per_channel, growth_fn_t, weighted_fn)

Compute the field

Jitted function with static argnums. Use functools.partial to bind the different functions. Avoid changing the shapes of non-static arguments between calls: each new shape triggers a recompilation.

Parameters
  • potential (jax._src.numpy.ndarray.ndarray) – [N, nb_kernels, world_dims...]

  • gf_params (jax._src.numpy.ndarray.ndarray) – [nb_kernels, nb_gf_params]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the averaging function [nb_channels, nb_kernels]

  • growth_fn_t (Tuple[Callable, ...]) – (jit static arg) Tuple of growth functions, of length nb_kernels

  • weighted_fn (Callable) – (jit static arg) Function used to merge fields linked to the same channel

Returns

An array containing the field

Return type

jax._src.numpy.ndarray.ndarray
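The field computation can be sketched as two stages: apply growth function i to potential channel i, then merge the per-kernel results into per-channel fields with weighted_fn. field_sketch, gaussian_growth and sum_fn are hypothetical names; the Gaussian growth function 2 * exp(-(u - mu)^2 / (2 * sigma^2)) - 1 with gf_params rows (mu, sigma) is the classic Lenia choice, assumed here for illustration only.

```python
import functools

import jax
import jax.numpy as jnp


def gaussian_growth(u, params):
    # Assumed Lenia-style growth: a bump centred on mu, mapped to [-1, 1]
    mu, sigma = params[0], params[1]
    return 2.0 * jnp.exp(-((u - mu) ** 2) / (2.0 * sigma ** 2)) - 1.0


@functools.partial(jax.jit, static_argnames=("growth_fn_t", "weighted_fn"))
def field_sketch(potential, gf_params, weights, growth_fn_t, weighted_fn):
    # potential: [N, nb_kernels, H, W]; gf_params: [nb_kernels, nb_gf_params]
    # The loop over the static tuple is unrolled once, at trace time
    fields = jnp.stack(
        [fn(potential[:, i], gf_params[i]) for i, fn in enumerate(growth_fn_t)],
        axis=1,
    )
    return weighted_fn(fields, weights)  # [N, nb_channels, H, W]


def sum_fn(fields, weights):
    # Merge kernels into channels: out[n, c] = sum_k weights[c, k] * fields[n, k]
    return jnp.einsum("nk...,ck->nc...", fields, weights)


potential = jnp.full((1, 2, 4, 4), 0.15)
gf_params = jnp.array([[0.15, 0.015], [0.3, 0.03]])
weights = jnp.array([[1.0, 0.0], [0.0, 1.0]])
field = field_sketch(
    potential, gf_params, weights,
    growth_fn_t=(gaussian_growth, gaussian_growth), weighted_fn=sum_fn,
)
```

Here the first kernel's potential sits exactly on its mu, giving a growth of 1, while the second is far from its mu and gives a growth close to -1.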

leniax.core.weighted_sum(fields, weights)

Compute the weighted sum of sub fields

Parameters
  • fields (jax._src.numpy.ndarray.ndarray) – Raw sub fields [N, nb_kernels, world_dims...]

  • weights (jax._src.numpy.ndarray.ndarray) – Weights used to compute the sum [nb_channels, nb_kernels]. A weight of 0. indicates that a given channel does not receive input from this kernel.

Returns

The unnormalized field

Return type

jax._src.numpy.ndarray.ndarray
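Given the shapes above, the weighted sum is a contraction over the kernel axis. A minimal sketch, assuming a plain einsum is an acceptable stand-in (weighted_sum_sketch is a hypothetical name, not the leniax implementation):

```python
import jax.numpy as jnp


def weighted_sum_sketch(fields, weights):
    # fields: [N, nb_kernels, world_dims...]; weights: [nb_channels, nb_kernels]
    # out[n, c] = sum_k weights[c, k] * fields[n, k]; a 0. weight drops kernel k from channel c
    return jnp.einsum("nk...,ck->nc...", fields, weights)


# Kernel k's sub-field is filled with the value k
fields = jnp.broadcast_to(jnp.arange(3.0)[None, :, None, None], (1, 3, 2, 2))
weights = jnp.array([[1.0, 1.0, 0.0],   # channel 0 <- kernels 0 and 1
                     [0.0, 0.0, 2.0]])  # channel 1 <- kernel 2, doubled
out = weighted_sum_sketch(fields, weights)  # channel 0 is 1.0, channel 1 is 4.0
```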

leniax.core.weighted_mean(fields, weights)

Compute the weighted mean of sub fields

Parameters
  • fields (jax._src.numpy.ndarray.ndarray) – Raw sub fields [N, nb_kernels, world_dims...]

  • weights (jax._src.numpy.ndarray.ndarray) – Weights used to compute the mean [nb_channels, nb_kernels]. A weight of 0. indicates that a given channel does not receive input from this kernel.

Returns

The normalized field

Return type

jax._src.numpy.ndarray.ndarray
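A weighted mean divides each channel's weighted sum by that channel's total weight. A minimal sketch under two assumptions: normalization uses the per-channel sum of weights, and a channel whose weights are all 0. gets a zero field instead of a division by zero. weighted_mean_sketch is a hypothetical name; the docstring does not specify how leniax handles that edge case.

```python
import jax.numpy as jnp


def weighted_mean_sketch(fields, weights):
    # fields: [N, nb_kernels, world_dims...]; weights: [nb_channels, nb_kernels]
    total = weights.sum(axis=1, keepdims=True)  # [nb_channels, 1]
    safe_total = jnp.where(total > 0, total, 1.0)
    # Assumption: channels with no incoming kernel get all-zero normalized weights
    norm = jnp.where(total > 0, weights / safe_total, 0.0)
    return jnp.einsum("nk...,ck->nc...", fields, norm)


fields = jnp.broadcast_to(jnp.arange(3.0)[None, :, None, None], (1, 3, 2, 2))
weights = jnp.array([[1.0, 1.0, 0.0],   # mean of kernels 0 and 1 -> 0.5
                     [0.0, 0.0, 2.0],   # kernel 2 only -> 2.0
                     [0.0, 0.0, 0.0]])  # no input -> 0.0
out = weighted_mean_sketch(fields, weights)
```

The inner jnp.where keeps the division finite, so no NaN or inf appears in the masked branch; this matters under jax.grad, where a NaN in an unused branch can still propagate into gradients.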