leniax.runner package
- leniax.runner.run(rng_key, cells, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn, stat_trunc=False)
Simulate a single configuration
It uses a Python for loop under the hood.
- Parameters
rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.
cells (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_init=1, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Stacked kernels (shape depends on the convolution implementation)
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [nb_kernels, params_shape...]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [nb_channels, nb_kernels]
T (jax._src.numpy.ndarray.ndarray) – Update rate [1]
max_run_iter (int) – Maximum number of simulation iterations
R (float) – Main kernel resolution
update_fn (Callable) – Function used to compute the new cell state
compute_stats_fn (Callable) – Function used to compute the statistics
stat_trunc (bool) – Set to True to truncate the run based on its statistics
- Returns
A 4-tuple holding the updated cells state, the used potential, the used field and a dictionary of statistics
- Return type
Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, Dict[str, jax._src.numpy.ndarray.ndarray]]
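The Python-loop structure described for `run` can be illustrated with a minimal sketch. It rests on several assumptions that are not taken from leniax: a single channel and a single kernel, convolution computed with an FFT on a periodic world, a Gaussian growth function with parameters `[m, s]`, and an Euler step with `dt = 1 / T` (the usual Lenia convention) clipped to `[0, 1]`. The helpers `make_kernel_fft`, `toy_update_fn`, `toy_stats_fn` and `toy_run`, as well as the `"mass"` statistic, are hypothetical stand-ins rather than leniax functions, and the PRNG key is omitted because the toy update is deterministic.

```python
import jax
import jax.numpy as jnp


def make_kernel_fft(H, W, R):
    """Hypothetical helper: FFT of a normalized ring kernel of radius R on an H x W torus."""
    ys, xs = jnp.arange(H), jnp.arange(W)
    # Toroidal distance of every cell from the origin, scaled by R.
    dy = jnp.minimum(ys, H - ys)
    dx = jnp.minimum(xs, W - xs)
    d = jnp.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2) / R
    # Gaussian shell peaking at d = 0.5, zero outside d < 1.
    k = jnp.where(d < 1.0, jnp.exp(-0.5 * ((d - 0.5) / 0.15) ** 2), 0.0)
    return jnp.fft.fft2(k / k.sum())


def toy_update_fn(cells, K_fft, gf_params, T):
    """Hypothetical Lenia-style step (not leniax's update_fn); cells: [N_init, 1, H, W]."""
    # Potential: circular convolution of the cells with the kernel.
    potential = jnp.real(jnp.fft.ifft2(jnp.fft.fft2(cells) * K_fft))
    # Field: Gaussian growth function with parameters [m, s], values in [-1, 1].
    m, s = gf_params[0], gf_params[1]
    field = 2.0 * jnp.exp(-((potential - m) ** 2) / (2.0 * s**2)) - 1.0
    # Assumed Euler step with dt = 1 / T, clipped to [0, 1].
    new_cells = jnp.clip(cells + field / T, 0.0, 1.0)
    return new_cells, potential, field


def toy_stats_fn(cells):
    """Hypothetical statistic: total mass of each initialization."""
    return {"mass": cells.sum(axis=(1, 2, 3))}


def toy_run(cells, K_fft, gf_params, T, max_run_iter, stat_trunc=False):
    all_cells, all_potentials, all_fields, all_stats = [cells], [], [], []
    for _ in range(max_run_iter):  # plain Python loop, evaluated eagerly
        cells, potential, field = toy_update_fn(cells, K_fft, gf_params, T)
        stats = toy_stats_fn(cells)
        all_cells.append(cells)
        all_potentials.append(potential)
        all_fields.append(field)
        all_stats.append(stats)
        # Eager execution allows an early exit: stop once every pattern has died out.
        if stat_trunc and float(stats["mass"].sum()) == 0.0:
            break
    stacked_stats = {name: jnp.stack([st[name] for st in all_stats]) for name in all_stats[0]}
    return (jnp.stack(all_cells), jnp.stack(all_potentials),
            jnp.stack(all_fields), stacked_stats)


cells = jax.random.uniform(jax.random.PRNGKey(0), (1, 1, 64, 64))  # [N_init=1, C=1, H, W]
K_fft = make_kernel_fft(64, 64, R=13.0)
all_cells, all_potentials, all_fields, stats = toy_run(
    cells, K_fft, jnp.array([0.15, 0.015]), jnp.array(10.0), max_run_iter=20
)
```

Because the loop runs eagerly in Python, it can `break` as soon as a statistic says the run is over, which is the kind of early exit that `stat_trunc=True` is meant to enable. The trade-off is that every iteration is dispatched separately instead of being compiled into a single XLA loop.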
- leniax.runner.run_scan(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)
Simulate a single configuration
This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate a single configuration with multiple initializations.
- Parameters
rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.
cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_init, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Stacked kernels [kernel_shape...]
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [nb_kernels, params_shape...]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [nb_channels, nb_kernels]
T (jax._src.numpy.ndarray.ndarray) – Update rate [1]
max_run_iter (int) – Maximum number of simulation iterations
R (float) – Main kernel resolution
update_fn (Callable) – Function used to compute the new cell state
compute_stats_fn (Callable) – Function used to compute the statistics
- Returns
A 4-tuple holding the updated cells state, the used potential, the used field and the simulation statistics
- Return type
Tuple[jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, jax._src.numpy.ndarray.ndarray, Dict[str, jax._src.numpy.ndarray.ndarray]]
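`run_scan` replaces the Python loop with `jax.lax.scan` inside a jitted function. The sketch below reuses a hypothetical toy model (single channel, FFT convolution, Gaussian growth, `dt = 1 / T`; these are assumptions, not leniax internals) to show the pattern: the carry is the cells state, the per-step outputs are stacked along a new leading time axis, and `max_run_iter` must be static because the scan length is fixed at trace time. The leading `N_init` axis of `cells0` simply flows through the batched FFT, which is one way several initializations can share a single run. `make_kernel_fft`, `toy_update_fn` and `toy_run_scan` are illustrative names.

```python
import functools

import jax
import jax.numpy as jnp


def make_kernel_fft(H, W, R):
    """Hypothetical helper: FFT of a normalized ring kernel of radius R on an H x W torus."""
    ys, xs = jnp.arange(H), jnp.arange(W)
    dy, dx = jnp.minimum(ys, H - ys), jnp.minimum(xs, W - xs)
    d = jnp.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2) / R
    k = jnp.where(d < 1.0, jnp.exp(-0.5 * ((d - 0.5) / 0.15) ** 2), 0.0)
    return jnp.fft.fft2(k / k.sum())


def toy_update_fn(cells, K_fft, gf_params, T):
    """Hypothetical Lenia-style step: FFT convolution, Gaussian growth, dt = 1 / T."""
    potential = jnp.real(jnp.fft.ifft2(jnp.fft.fft2(cells) * K_fft))
    m, s = gf_params[0], gf_params[1]
    field = 2.0 * jnp.exp(-((potential - m) ** 2) / (2.0 * s**2)) - 1.0
    return jnp.clip(cells + field / T, 0.0, 1.0), potential, field


@functools.partial(jax.jit, static_argnums=(4,))  # max_run_iter sets the scan length
def toy_run_scan(cells0, K_fft, gf_params, T, max_run_iter):
    def step(cells, _unused):
        new_cells, potential, field = toy_update_fn(cells, K_fft, gf_params, T)
        stats = {"mass": new_cells.sum(axis=(1, 2, 3))}
        # carry = next state; the second element is stacked by scan along axis 0
        return new_cells, (new_cells, potential, field, stats)

    _, (all_cells, all_potentials, all_fields, stats) = jax.lax.scan(
        step, cells0, xs=None, length=max_run_iter
    )
    return all_cells, all_potentials, all_fields, stats


cells0 = jax.random.uniform(jax.random.PRNGKey(0), (4, 1, 64, 64))  # N_init = 4
K_fft = make_kernel_fft(64, 64, R=13.0)
all_cells, all_potentials, all_fields, stats = toy_run_scan(
    cells0, K_fft, jnp.array([0.15, 0.015]), jnp.array(10.0), 20
)
# all_cells: [20, 4, 1, 64, 64]; stats["mass"]: [20, 4]
```

A scan always runs for its full `length`, so there is no early exit comparable to the `stat_trunc` option of the Python-loop version; in exchange, the whole loop is compiled once.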
- leniax.runner.run_scan_mem_optimized(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)
Vectorized version of run_scan_mem_optimized. Takes similar arguments as run_scan_mem_optimized but with additional array axes over which run_scan_mem_optimized is mapped.
Original documentation:
Simulate multiple configurations
This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate multiple configurations, each with multiple initializations.
- Parameters
rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.
cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_sols, N_init, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Stacked kernels [N_sols, kernel_shape...]
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [N_sols, nb_kernels, params_shape...]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [N_sols, nb_channels, nb_kernels]
T (jax._src.numpy.ndarray.ndarray) – Update rate [N_sols]
max_run_iter (int) – Maximum number of simulation iterations
R (float) – Main kernel resolution
update_fn (Callable) – Function used to compute the new cell state
compute_stats_fn (Callable) – Function used to compute the statistics
- Returns
A 2-tuple holding the simulation statistics and the final cells states
- Return type
Tuple[Dict[str, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]
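The "Vectorized version of ..." wording above is the docstring that `jax.vmap` attaches to the functions it wraps, so `run_scan_mem_optimized` maps over a leading `N_sols` axis of its array arguments. A minimal sketch of that pattern, again on a hypothetical toy model (single channel, FFT convolution, Gaussian growth, `dt = 1 / T`): the scan keeps only the cells as its carry and emits per-step statistics instead of stacking every cells/potential/field snapshot, which is one plausible reading of "memory optimized". `jax.vmap` then maps the whole run over `N_sols`, while `max_run_iter` stays a shared static Python int. All helper names are illustrative, not leniax functions.

```python
import functools

import jax
import jax.numpy as jnp


def make_kernel_fft(H, W, R):
    """Hypothetical helper: FFT of a normalized ring kernel of radius R on an H x W torus."""
    ys, xs = jnp.arange(H), jnp.arange(W)
    dy, dx = jnp.minimum(ys, H - ys), jnp.minimum(xs, W - xs)
    d = jnp.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2) / R
    k = jnp.where(d < 1.0, jnp.exp(-0.5 * ((d - 0.5) / 0.15) ** 2), 0.0)
    return jnp.fft.fft2(k / k.sum())


def toy_update_fn(cells, K_fft, gf_params, T):
    """Hypothetical Lenia-style step: FFT convolution, Gaussian growth, dt = 1 / T."""
    potential = jnp.real(jnp.fft.ifft2(jnp.fft.fft2(cells) * K_fft))
    m, s = gf_params[0], gf_params[1]
    field = 2.0 * jnp.exp(-((potential - m) ** 2) / (2.0 * s**2)) - 1.0
    return jnp.clip(cells + field / T, 0.0, 1.0), potential, field


def toy_run_scan_mem(cells0, K_fft, gf_params, T, max_run_iter):
    """One configuration: keep only the final cells and per-step statistics."""
    def step(cells, _unused):
        new_cells, _, _ = toy_update_fn(cells, K_fft, gf_params, T)
        # Potential and field are discarded instead of being stacked over time.
        return new_cells, {"mass": new_cells.sum(axis=(1, 2, 3))}

    final_cells, stats = jax.lax.scan(step, cells0, xs=None, length=max_run_iter)
    return stats, final_cells


@functools.partial(jax.jit, static_argnums=(4,))
def toy_run_scan_mem_batched(cells0, K_fft, gf_params, T, max_run_iter):
    run_one = functools.partial(toy_run_scan_mem, max_run_iter=max_run_iter)
    # Map every array argument over its leading N_sols axis (in_axes=0 by default).
    return jax.vmap(run_one)(cells0, K_fft, gf_params, T)


H = W = 64
n_sols, n_init = 3, 2
cells0 = jax.random.uniform(jax.random.PRNGKey(0), (n_sols, n_init, 1, H, W))
K_fft = jnp.stack([make_kernel_fft(H, W, R) for R in (10.0, 13.0, 16.0)])  # [N_sols, H, W]
gf_params = jnp.array([[0.15, 0.015], [0.15, 0.02], [0.2, 0.03]])  # [N_sols, 2]
T = jnp.array([10.0, 10.0, 5.0])  # [N_sols]
stats, final_cells = toy_run_scan_mem_batched(cells0, K_fft, gf_params, T, 50)
# stats["mass"]: [N_sols, 50, N_init]; final_cells: [N_sols, N_init, 1, H, W]
```

Keeping only statistics means memory grows with `N_sols * max_run_iter * N_init` scalars rather than with full world snapshots per step.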
- leniax.runner.run_scan_mem_optimized_pmap(rng_key, cells0, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn)
Vectorized version of run_scan_mem_optimized_pmap. Takes similar arguments as run_scan_mem_optimized_pmap but with additional array axes over which run_scan_mem_optimized_pmap is mapped.
Original documentation:
Simulate multiple configurations on multiple devices
This function is jitted and uses jax.lax.scan under the hood. It can be used to simulate multiple configurations, each with multiple initializations, across multiple devices.
- Parameters
rng_key (jax._src.prng.PRNGKeyArray) – JAX PRNG key.
cells0 (jax._src.numpy.ndarray.ndarray) – Initial cells state [N_device, N_sols, N_init, nb_channels, world_dims...]
K (jax._src.numpy.ndarray.ndarray) – Stacked kernels [N_device, N_sols, kernel_shape...]
gf_params (jax._src.numpy.ndarray.ndarray) – Growth function parameters [N_device, N_sols, nb_kernels, params_shape...]
kernels_weight_per_channel (jax._src.numpy.ndarray.ndarray) – Kernel weights used in the average function [N_device, N_sols, nb_channels, nb_kernels]
T (jax._src.numpy.ndarray.ndarray) – Update rate [N_device, N_sols]
max_run_iter (int) – Maximum number of simulation iterations
R (float) – Main kernel resolution
update_fn (Callable) – Function used to compute the new cell state
compute_stats_fn (Callable) – Function used to compute the statistics
- Returns
A 2-tuple holding the simulation statistics and the final cells states
- Return type
Tuple[Dict[str, jax._src.numpy.ndarray.ndarray], jax._src.numpy.ndarray.ndarray]
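For `run_scan_mem_optimized_pmap`, every array argument gains a leading `N_device` axis. The sketch below shows one way such inputs could be prepared and dispatched with `jax.pmap`, again on a hypothetical toy model (single channel, FFT convolution, Gaussian growth, `dt = 1 / T`): split `N_sols` into `[N_device, N_sols / N_device]`, `pmap` a `vmap`-ped run over the device axis with `max_run_iter` passed through `static_broadcasted_argnums`, then merge the two leading axes of the outputs. `run_sols`, `run_on_devices` and the other helpers are illustrative names, not part of leniax. On a machine with a single device, `jax.local_device_count()` is 1 and the code still runs.

```python
import functools

import jax
import jax.numpy as jnp


def make_kernel_fft(H, W, R):
    """Hypothetical helper: FFT of a normalized ring kernel of radius R on an H x W torus."""
    ys, xs = jnp.arange(H), jnp.arange(W)
    dy, dx = jnp.minimum(ys, H - ys), jnp.minimum(xs, W - xs)
    d = jnp.sqrt(dy[:, None] ** 2 + dx[None, :] ** 2) / R
    k = jnp.where(d < 1.0, jnp.exp(-0.5 * ((d - 0.5) / 0.15) ** 2), 0.0)
    return jnp.fft.fft2(k / k.sum())


def toy_run_scan_mem(cells0, K_fft, gf_params, T, max_run_iter):
    """Hypothetical single-configuration run returning (stats, final_cells)."""
    m, s = gf_params[0], gf_params[1]

    def step(cells, _unused):
        potential = jnp.real(jnp.fft.ifft2(jnp.fft.fft2(cells) * K_fft))
        field = 2.0 * jnp.exp(-((potential - m) ** 2) / (2.0 * s**2)) - 1.0
        new_cells = jnp.clip(cells + field / T, 0.0, 1.0)  # assumed dt = 1 / T
        return new_cells, {"mass": new_cells.sum(axis=(1, 2, 3))}

    final_cells, stats = jax.lax.scan(step, cells0, xs=None, length=max_run_iter)
    return stats, final_cells


def run_sols(cells0, K_fft, gf_params, T, max_run_iter):
    # Per device: vectorize over the local N_sols axis.
    run_one = functools.partial(toy_run_scan_mem, max_run_iter=max_run_iter)
    return jax.vmap(run_one)(cells0, K_fft, gf_params, T)


# Map over the leading device axis; max_run_iter is broadcast as a static value.
run_sols_pmap = jax.pmap(run_sols, static_broadcasted_argnums=(4,))


def run_on_devices(cells0, K_fft, gf_params, T, max_run_iter):
    n_dev = jax.local_device_count()
    if cells0.shape[0] % n_dev != 0:
        raise ValueError("N_sols must be divisible by the number of local devices")

    def split(x):  # [N_sols, ...] -> [N_device, N_sols / N_device, ...]
        return x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])

    def merge(x):  # [N_device, N_sols / N_device, ...] -> [N_sols, ...]
        return x.reshape((-1,) + x.shape[2:])

    stats, final_cells = run_sols_pmap(
        split(cells0), split(K_fft), split(gf_params), split(T), max_run_iter
    )
    return jax.tree_util.tree_map(merge, stats), merge(final_cells)


n_sols = 2 * jax.local_device_count()
cells0 = jax.random.uniform(jax.random.PRNGKey(0), (n_sols, 2, 1, 64, 64))
K_fft = jnp.stack([make_kernel_fft(64, 64, 13.0)] * n_sols)
gf_params = jnp.tile(jnp.array([0.15, 0.015]), (n_sols, 1))
T = jnp.full((n_sols,), 10.0)
stats, final_cells = run_on_devices(cells0, K_fft, gf_params, T, 30)
# stats["mass"]: [N_sols, 30, N_init]; final_cells: [N_sols, N_init, 1, 64, 64]
```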