leniax.runner package

leniax.runner.run(rng_key, cells, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn, stat_trunc=False)

Simulate a single configuration

It uses a Python for loop under the hood.

Parameters
  • rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.

  • cells (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_init=1, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Stacked Kernels (shape depends on convolution implementation)

  • gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [nb_kernels, params_shape...]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [nb_channels, nb_kernels]

  • T (jax._src.numpy.ndarray.ndarray) – Update rate [1]

  • max_run_iter (int) – Maximum number of simulation iterations

  • R (float) – Main kernel resolution

  • update_fn (Callable) – Function used to compute the new cell state

  • compute_stats_fn (Callable) – Function used to compute the statistics

  • stat_trunc (bool) – Set to True to truncate the run based on its statistics

Returns

A 5-tuple containing a JAX PRNG key, the updated cells state, the potential and field that were used, and the statistics

Return type

Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, Dict[str, jax._src.numpy.ndarray.ndarray]]
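The loop-based driver described above can be pictured with the following minimal sketch. It is not leniax's implementation: toy_update_fn and run_python_loop are hypothetical names, the update rule is a toy decay, rng_key, K, gf_params and the kernel weights are omitted, dt = 1/T is an assumption about how T is used, and the mass threshold only stands in for whatever criterion stat_trunc actually applies.

```python
import jax.numpy as jnp


def toy_update_fn(cells, dt):
    # Hypothetical update rule (not leniax's): exponential decay clipped to [0, 1].
    field = -cells
    new_cells = jnp.clip(cells + dt * field, 0.0, 1.0)
    return new_cells, field


def run_python_loop(cells, T, max_run_iter, stat_trunc=False, min_mass=1e-3):
    # cells: [N_init=1, nb_channels, world_dims...]
    dt = 1.0 / T  # assumption: T is the number of steps per unit of time
    all_cells, all_fields, masses = [], [], []
    for _ in range(max_run_iter):
        cells, field = toy_update_fn(cells, dt)
        # float() forces a device sync every step, which a Python loop can afford.
        mass = float(cells.sum())
        all_cells.append(cells)
        all_fields.append(field)
        masses.append(mass)
        # Hypothetical truncation rule: stop once the pattern has (almost) vanished.
        if stat_trunc and mass < min_mass:
            break
    return jnp.stack(all_cells), jnp.stack(all_fields), {"mass": jnp.array(masses)}
```

Because the loop runs in Python, it can stop early on a data-dependent condition, at the cost of dispatching every step separately instead of compiling the whole run.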

leniax.runner.run_scan(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)

Simulate a single configuration

This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate a single configuration with multiple initializations.

Parameters
  • rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.

  • cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_init, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Stacked Kernels [kernel_shape...]

  • gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [nb_kernels, params_shape...]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [nb_channels, nb_kernels]

  • T (jax._src.numpy.ndarray.ndarray) – Update rate [1]

  • max_run_iter (int) – Maximum number of simulation iterations

  • R (float) – Main kernel resolution

  • update_fn (Callable) – Function used to compute the new cell state

  • compute_stats_fn (Callable) – Function used to compute the statistics

Returns

A 4-tuple of arrays containing the updated cells state, the potential and field that were used, and the simulation statistics

Return type

Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, Dict[str, jax._src.numpy.ndarray.ndarray]]
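A jitted, scan-based run over several initializations of one configuration can be sketched as below. This is an illustration under stated assumptions, not leniax's code: toy_update_fn and run_scan_sketch are hypothetical, the update rule is a toy decay, rng_key, K, gf_params and the kernel weights are omitted, and dt = 1/T is assumed. All N_init initializations advance together because the update acts on the whole leading axis.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_update_fn(cells, dt):
    # Hypothetical update rule (not leniax's): exponential decay clipped to [0, 1].
    field = -cells
    return jnp.clip(cells + dt * field, 0.0, 1.0), field


@partial(jax.jit, static_argnums=(2,))
def run_scan_sketch(cells0, T, max_run_iter):
    # cells0: [N_init, nb_channels, world_dims...]
    dt = 1.0 / T  # assumption: T is the number of steps per unit of time

    def step(cells, _):
        new_cells, field = toy_update_fn(cells, dt)
        # One mass value per initialization, summed over channel and world axes.
        mass = new_cells.sum(axis=tuple(range(1, new_cells.ndim)))
        return new_cells, (new_cells, field, mass)

    # scan always runs exactly max_run_iter steps and stacks the per-step outputs.
    _, (all_cells, all_fields, masses) = jax.lax.scan(step, cells0, None, length=max_run_iter)
    return all_cells, all_fields, {"mass": masses}
```

A scan compiles to a fixed number of steps, so it cannot break out early the way a Python loop can; that may be why run_scan exposes no stat_trunc flag. max_run_iter is marked static because it sets the scan length.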

leniax.runner.run_scan_mem_optimized(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)

Vectorized (jax.vmap) version of a single-solution scan runner. It takes the same arguments as that function, with an additional leading array axis (N_sols) over which it is mapped.

Original documentation:

Simulate multiple configurations

This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate multiple configurations, each with multiple initializations.

Parameters
  • rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.

  • cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_sols, N_init, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Stacked Kernels [N_sols, kernel_shape...]

  • gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [N_sols, nb_kernels, params_shape...]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [N_sols, nb_channels, nb_kernels]

  • T (jax._src.numpy.ndarray.ndarray) – Update rate [N_sols]

  • max_run_iter (int) – Maximum number of simulation iterations

  • R (float) – Main kernel resolution

  • update_fn (Callable) – Function used to compute the new cell state

  • compute_stats_fn (Callable) – Function used to compute the statistics

Returns

A 3-tuple containing a JAX PRNG key, the simulation statistics and the final cells states

Return type

Tuple[Dict[str, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]
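Mapping a scan-based runner over an extra N_sols axis with jax.vmap can be sketched as follows. The names toy_update_fn, make_batched_runner and run_one_solution are hypothetical, rng_key, K, gf_params and the kernel weights are omitted, dt = 1/T is assumed, and a 2D world is assumed in the mass reduction. Mirroring the documented return type, the sketch keeps only per-step statistics and the final cells rather than the full trajectory; this is one plausible reading of "memory optimized", not a description of leniax's internals.

```python
import jax
import jax.numpy as jnp


def toy_update_fn(cells, dt):
    # Hypothetical update rule (not leniax's): exponential decay clipped to [0, 1].
    field = -cells
    return jnp.clip(cells + dt * field, 0.0, 1.0), field


def make_batched_runner(max_run_iter):
    def run_one_solution(cells0, T):
        # cells0: [N_init, C, H, W] (2D world assumed); T: scalar for this solution.
        dt = 1.0 / T  # assumption: T is the number of steps per unit of time

        def step(cells, _):
            new_cells, _field = toy_update_fn(cells, dt)
            # Only a small statistic is stacked per step, not the full cells state.
            return new_cells, new_cells.sum(axis=(1, 2, 3))

        final_cells, masses = jax.lax.scan(step, cells0, None, length=max_run_iter)
        return {"mass": masses}, final_cells

    # Map over a new leading N_sols axis of both cells0 and T, then compile once.
    return jax.jit(jax.vmap(run_one_solution, in_axes=(0, 0)))
```

Because jax.vmap adds the N_sols axis, every solution must share the same array shapes (same N_init and world size); only values such as T can differ between solutions.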

leniax.runner.run_scan_mem_optimized_pmap(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)

Vectorized (jax.vmap) version of a single-solution scan runner, additionally mapped across devices. It takes the same arguments as that function, with additional leading array axes (N_device, N_sols) over which it is mapped.

Original documentation:

Simulate multiple configurations on multiple devices

This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate multiple configurations, each with multiple initializations, on multiple devices.

Parameters
  • rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.

  • cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_device, N_sols, N_init, nb_channels, world_dims...]

  • K (jax._src.numpy.ndarray.ndarray) – Stacked Kernels [N_device, N_sols, kernel_shape...]

  • gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [N_device, N_sols, nb_kernels, params_shape...]

  • kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [N_device, N_sols, nb_channels, nb_kernels]

  • T (jax._src.numpy.ndarray.ndarray) – Update rate [N_device, N_sols]

  • max_run_iter (int) – Maximum number of simulation iterations

  • R (float) – Main kernel resolution

  • update_fn (Callable) – Function used to compute the new cell state

  • compute_stats_fn (Callable) – Function used to compute the statistics

Returns

A 3-tuple containing a JAX PRNG key, the simulation statistics and the final cells states

Return type

Tuple[Dict[str, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]
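Spreading solutions across devices by layering jax.pmap over jax.vmap can be sketched as below. Everything here is hypothetical (toy_update_fn, make_multi_device_runner, split_across_devices), rng_key, K, gf_params and the kernel weights are omitted, dt = 1/T is assumed, and a 2D world is assumed. The helper reshapes a flat batch of solutions into the [N_device, N_sols, ...] layout listed in the parameters above.

```python
import jax
import jax.numpy as jnp


def toy_update_fn(cells, dt):
    # Hypothetical update rule (not leniax's): exponential decay clipped to [0, 1].
    field = -cells
    return jnp.clip(cells + dt * field, 0.0, 1.0), field


def make_multi_device_runner(max_run_iter):
    def run_one_solution(cells0, T):
        # cells0: [N_init, C, H, W] (2D world assumed); T: scalar for this solution.
        dt = 1.0 / T  # assumption: T is the number of steps per unit of time

        def step(cells, _):
            new_cells, _field = toy_update_fn(cells, dt)
            return new_cells, new_cells.sum(axis=(1, 2, 3))

        final_cells, masses = jax.lax.scan(step, cells0, None, length=max_run_iter)
        return {"mass": masses}, final_cells

    # vmap over N_sols on each device; pmap (which also compiles) over N_device.
    return jax.pmap(jax.vmap(run_one_solution))


def split_across_devices(x):
    # Reshape [N_total, ...] into [N_device, N_total // N_device, ...].
    n_dev = jax.local_device_count()
    assert x.shape[0] % n_dev == 0, "N_total must be divisible by the device count"
    return x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])
```

On a machine with a single device, jax.local_device_count() returns 1 and the pmap axis has size 1, so the sketch still runs unchanged.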