"""
.. highlight:: python
"""
import functools
import jax
from jax import vmap, pmap, lax, jit
from jax.tree_util import tree_map
import jax.numpy as jnp
from typing import Callable, Dict, Tuple
from . import statistics as leniax_stat
from .constant import EPSILON, START_CHECK_STOP
def run(
    rng_key: jnp.ndarray,
    cells: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    T: jnp.ndarray,
    max_run_iter: int,
    R: float,
    update_fn: Callable,
    compute_stats_fn: Callable,
    stat_trunc: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Simulate a single configuration

    It uses a python ``for`` loop under the hood.

    Args:
        rng_key: JAX PRNG key, split into one sub-key per iteration.
        cells: Initial cells state ``[N_init=1, nb_channels, world_dims...]``
        K: Stacked Kernels (shape depends on convolution implementation)
        gf_params: Growth function parameters ``[nb_kernels, params_shape...]``
        kernels_weight_per_channel: Kernels weight used in the average function ``[nb_channels, nb_kernels]``
        T: Inverse of the update rate; ``update_fn`` is called with ``1. / T``
        max_run_iter: Maximum number of simulation iterations
        R: Main kernel Resolution (not read by this function)
        update_fn: Function used to compute the new cell state
        compute_stats_fn: Function used to compute the statistics
        stat_trunc: Set to ``True`` to truncate run based on its statistics

    Returns:
        A 4-tuple: the stacked cells states fed to each iteration, the stacked
        second and third outputs of ``update_fn`` (named ``field`` and
        ``potential`` here) and the statistics dictionary. The dictionary also
        holds ``'N'``, the index of the last executed iteration.

    Raises:
        AssertionError: If ``max_run_iter`` is not positive or ``cells`` does
            not hold exactly one initialization.
    """
    assert max_run_iter > 0, f"max_run_iter must be positive, value given: {max_run_iter}"
    assert cells.shape[0] == 1

    # cells shape: [N=1, C, dims...]
    N = 1
    nb_world_dims = cells.ndim - 2

    all_cells = [cells]
    all_fields = []
    all_potentials = []
    all_stats = []

    init_mass = cells.sum()
    previous_mass = init_mass
    previous_sign = jnp.zeros(N, dtype=jnp.int32)
    counters = leniax_stat.init_counters(N)
    should_continue = jnp.ones(N, dtype=jnp.int32)

    # Same initial statistics state (shapes and dtypes) as `_get_init_carry`, so that
    # `compute_stats_fn` sees the same input types on the first and on later iterations.
    total_shift_idx = jnp.zeros([N, nb_world_dims], dtype=jnp.int32)
    mass_centroid = jnp.zeros([nb_world_dims, N])
    mass_angle = jnp.zeros([N])

    subkeys = jax.random.split(rng_key, max_run_iter)
    for current_iter in range(max_run_iter):
        new_cells, field, potential = update_fn(
            subkeys[current_iter], cells, K, gf_params, kernels_weight_per_channel, 1. / T
        )

        # Statistics are computed on the state that was fed to `update_fn`
        stat_t, total_shift_idx, mass_centroid, mass_angle = compute_stats_fn(
            cells, field, potential, total_shift_idx, mass_centroid, mass_angle
        )

        cells = new_cells
        all_cells.append(cells)
        all_fields.append(field)
        all_potentials.append(potential)
        all_stats.append(stat_t)

        mass = stat_t['mass']
        cond = leniax_stat.min_mass_heuristic(EPSILON, mass)
        should_continue_cond = cond
        cond = leniax_stat.max_mass_heuristic(init_mass, mass)
        should_continue_cond *= cond

        sign = jnp.sign(mass - previous_mass)
        monotone_counter = counters['nb_monotone_step']
        cond, counters['nb_monotone_step'] = leniax_stat.monotonic_heuristic(sign, previous_sign, monotone_counter)
        should_continue_cond *= cond

        mass_volume = stat_t['mass_volume']
        mass_volume_counter = counters['nb_max_volume_step']
        cond, counters['nb_max_volume_step'] = leniax_stat.mass_volume_heuristic(mass_volume, mass_volume_counter)
        should_continue_cond *= cond

        should_continue *= should_continue_cond

        # Remember this step's mass and sign: without this the monotonic heuristic
        # would always compare against the initial mass and a zero sign.
        previous_mass = mass
        previous_sign = sign

        # We avoid dismissing a simulation during the init period
        if stat_trunc is True and current_iter >= START_CHECK_STOP and should_continue == 0:
            break

    # To keep the same number of elements per array: drop the last state, which
    # has no matching field / potential / statistics entry.
    all_cells.pop()

    all_cells_jnp = jnp.array(all_cells)
    all_fields_jnp = jnp.array(all_fields)
    all_potentials_jnp = jnp.array(all_potentials)

    stats_dict = leniax_stat.stats_list_to_dict(all_stats)
    stats_dict['N'] = jnp.array(current_iter)

    return all_cells_jnp, all_fields_jnp, all_potentials_jnp, stats_dict
@functools.partial(jit, static_argnums=(6, 7, 8, 9))
def run_scan(
    rng_key: jnp.ndarray,
    cells0: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    T: jnp.ndarray,
    max_run_iter: int,
    R: float,
    update_fn: Callable,
    compute_stats_fn: Callable
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Simulate a single configuration

    This function is jitted, it uses jax.lax.scan function under the hood.
    It can be used to simulate a single configuration with multiple initialization.

    Args:
        rng_key: JAX PRNG key, split into one sub-key per iteration.
        cells0: Initial cells state ``[N_init, nb_channels, world_dims...]``
        K: Stacked Kernels ``[kernel_shape...]``
        gf_params: Growth function parameters ``[nb_kernels, params_shape...]``
        kernels_weight_per_channel: Kernels weight used in the average function ``[nb_channels, nb_kernels]``
        T: Inverse of the update rate; ``update_fn`` is called with ``1. / T``
        max_run_iter: Number of simulation iterations (static argument)
        R: Main kernel Resolution (static argument, not read by this function)
        update_fn: Function used to compute the new cell state (static argument)
        compute_stats_fn: Function used to compute the statistics (static argument)

    Returns:
        A 4-tuple: the cells state fed to each iteration, the ``field`` and
        ``potential`` values returned by ``update_fn`` at each iteration, and
        the simulation statistics. The statistics also hold ``'N'``, the sum
        over the first axis of ``leniax_stat.check_heuristics``'s output.
    """
    init_carry = _get_init_carry(cells0, K, gf_params, kernels_weight_per_channel, T)
    fn: Callable = functools.partial(
        _scan_fn, update_fn=update_fn, compute_stats_fn=compute_stats_fn, keep_intermediary_data=True
    )

    # One PRNG key per iteration: the number of keys sets the number of scan steps.
    # `jax.random.split` already returns a stacked array, no conversion is needed.
    subkeys = jax.random.split(rng_key, max_run_iter)
    _, ys = lax.scan(fn, init_carry, subkeys)

    continue_stat = leniax_stat.check_heuristics(ys['stats'])
    ys['stats']['N'] = continue_stat.sum(axis=0)

    return ys['cells'], ys['field'], ys['potential'], ys['stats']
@functools.partial(jit, static_argnums=(6, 7, 8, 9))
@functools.partial(vmap, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, None), out_axes=0)
def run_scan_mem_optimized(
    rng_key: jnp.ndarray,
    cells0: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    T: jnp.ndarray,
    max_run_iter: int,
    R: float,
    update_fn: Callable,
    compute_stats_fn: Callable
) -> Tuple[Dict[str, jnp.ndarray], jnp.ndarray]:
    """Simulate multiple configurations

    This function is jitted, it uses jax.lax.scan function under the hood.
    It can be used to simulate multiple configurations with multiple initialization.
    Only the statistics are kept for each timestep; cells, fields and
    potentials of intermediate steps are not returned.

    Args:
        rng_key: JAX PRNG key (shared by every configuration, not mapped).
        cells0: Initial cells state ``[N_sols, N_init, nb_channels, world_dims...]``
        K: Stacked Kernels ``[N_sols, kernel_shape...]``
        gf_params: Growth function parameters ``[N_sols, nb_kernels, params_shape...]``
        kernels_weight_per_channel: Kernels weight used in the average function ``[N_sols, nb_channels, nb_kernels]``
        T: Inverse of the update rate ``[N_sols]``; ``update_fn`` is called with ``1. / T``
        max_run_iter: Number of simulation iterations (static argument)
        R: Main kernel Resolution (static argument, not read by this function)
        update_fn: Function used to compute the new cell state (static argument)
        compute_stats_fn: Function used to compute the statistics (static argument)

    Returns:
        A 2-tuple: the simulations statistics (including ``'N'``, the sum over
        the first axis of ``leniax_stat.check_heuristics``'s output) and the
        final cells states.
    """
    init_carry = _get_init_carry(cells0, K, gf_params, kernels_weight_per_channel, T)
    fn: Callable = functools.partial(
        _scan_fn, update_fn=update_fn, compute_stats_fn=compute_stats_fn, keep_intermediary_data=False
    )

    # One PRNG key per iteration: the number of keys sets the number of scan steps.
    # `jax.random.split` already returns a stacked array, no conversion is needed.
    subkeys = jax.random.split(rng_key, max_run_iter)
    final_carry, ys = lax.scan(fn, init_carry, subkeys)

    # The cells are the first entry of the carried `fn_params` tuple
    final_cells = final_carry['fn_params'][0]

    stats = ys['stats']
    continue_stat = leniax_stat.check_heuristics(stats)
    stats['N'] = continue_stat.sum(axis=0)

    return stats, final_cells
@functools.partial(
    pmap, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, None), out_axes=0, static_broadcasted_argnums=(6, 7, 8, 9)
)
@functools.partial(vmap, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, None), out_axes=0)
def run_scan_mem_optimized_pmap(
    rng_key: jnp.ndarray,
    cells0: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    T: jnp.ndarray,
    max_run_iter: int,
    R: float,
    update_fn: Callable,
    compute_stats_fn: Callable
) -> Tuple[Dict[str, jnp.ndarray], jnp.ndarray]:
    """Simulate multiple configurations on multiple devices

    This function is compiled through ``pmap``, it uses jax.lax.scan function under the hood.
    It can be used to simulate multiple configurations with multiple initialization
    on multiple devices.
    Only the statistics are kept for each timestep; cells, fields and
    potentials of intermediate steps are not returned.

    Args:
        rng_key: JAX PRNG key (shared by every device and configuration, not mapped).
        cells0: Initial cells state ``[N_device, N_sols, N_init, nb_channels, world_dims...]``
        K: Stacked Kernels ``[N_device, N_sols, kernel_shape...]``
        gf_params: Growth function parameters ``[N_device, N_sols, nb_kernels, params_shape...]``
        kernels_weight_per_channel: Kernels weight used in the average function ``[N_device, N_sols, nb_channels, nb_kernels]``
        T: Inverse of the update rate ``[N_device, N_sols]``; ``update_fn`` is called with ``1. / T``
        max_run_iter: Number of simulation iterations (static argument)
        R: Main kernel Resolution (static argument, not read by this function)
        update_fn: Function used to compute the new cell state (static argument)
        compute_stats_fn: Function used to compute the statistics (static argument)

    Returns:
        A 2-tuple: the simulations statistics (including ``'N'``, the sum over
        the first axis of ``leniax_stat.check_heuristics``'s output) and the
        final cells states.
    """
    init_carry = _get_init_carry(cells0, K, gf_params, kernels_weight_per_channel, T)
    fn: Callable = functools.partial(
        _scan_fn, update_fn=update_fn, compute_stats_fn=compute_stats_fn, keep_intermediary_data=False
    )

    # One PRNG key per iteration: the number of keys sets the number of scan steps.
    # `jax.random.split` already returns a stacked array, no conversion is needed.
    subkeys = jax.random.split(rng_key, max_run_iter)
    final_carry, ys = lax.scan(fn, init_carry, subkeys)

    # The cells are the first entry of the carried `fn_params` tuple
    final_cells = final_carry['fn_params'][0]

    stats = ys['stats']
    continue_stat = leniax_stat.check_heuristics(stats)
    stats['N'] = continue_stat.sum(axis=0)

    return stats, final_cells
def _get_init_carry(
    cells0: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    T: jnp.ndarray,
    with_stat: bool = True,
) -> Dict:
    """Build the initial carry dictionary used by the scan step functions.

    Args:
        cells0: Initial cells state; its first axis sets the batch size and
            every axis after the second counts as a world dimension.
        K: Stacked kernels, stored untouched in the carry.
        gf_params: Growth function parameters, stored untouched in the carry.
        kernels_weight_per_channel: Kernels weights, stored untouched in the carry.
        T: Stored untouched in the carry.
        with_stat: When ``True``, also add zero-initialized statistics state
            under the ``'stats_properties'`` key.

    Returns:
        A dictionary holding ``'fn_params'`` (the 5-tuple of inputs) and,
        optionally, ``'stats_properties'``.
    """
    batch_size = cells0.shape[0]
    world_rank = cells0.ndim - 2

    carry: Dict = {'fn_params': (cells0, K, gf_params, kernels_weight_per_channel, T)}
    if with_stat is not True:
        return carry

    shift_zeros = jnp.zeros((batch_size, world_rank), dtype=jnp.int32)
    centroid_zeros = jnp.zeros((world_rank, batch_size))
    angle_zeros = jnp.zeros((batch_size, ))
    carry['stats_properties'] = {
        'total_shift_idx': shift_zeros,
        'mass_centroid': centroid_zeros,
        'mass_angle': angle_zeros,
    }
    return carry
@functools.partial(jit, static_argnums=(2, 3, 4, 5))
def _scan_fn(
    carry: Dict,
    x_rng_key: jnp.ndarray,
    update_fn: Callable,
    compute_stats_fn: Callable,
    keep_intermediary_data: bool = False,
    keep_all_timesteps: bool = True,
) -> Tuple[Dict, Dict]:
    """Scan step: update the cells and compute the statistics of the step.

    Args:
        carry: Dictionary with ``'fn_params'`` (cells, K, gf_params,
            kernels_weight_per_channel, T) and ``'stats_properties'``
            (``total_shift_idx``, ``mass_centroid``, ``mass_angle``).
        x_rng_key: PRNG key of this step, forwarded to ``update_fn``.
        update_fn: Returns the new cells, the field and the potential.
        compute_stats_fn: Returns the statistics and the updated statistics state.
        keep_intermediary_data: Also emit cells, field and potential.
        keep_all_timesteps: When not ``True`` nothing is emitted for the step.

    Returns:
        The carry for the next step and the per-step output dictionary.
    """
    cells, K, gf_params, kernels_weight_per_channel, T = carry['fn_params']
    inv_T = 1. / T
    new_cells, field, potential = update_fn(x_rng_key, cells, K, gf_params, kernels_weight_per_channel, inv_T)

    # Statistics are computed on the cells that were fed to `update_fn`
    previous_props = carry['stats_properties']
    stats, total_shift_idx, mass_centroid, mass_angle = compute_stats_fn(
        cells,
        field,
        potential,
        previous_props['total_shift_idx'],
        previous_props['mass_centroid'],
        previous_props['mass_angle'],
    )

    next_props = {
        'total_shift_idx': total_shift_idx,
        'mass_centroid': mass_centroid,
        'mass_angle': mass_angle,
    }
    next_carry = {
        'fn_params': (new_cells, K, gf_params, kernels_weight_per_channel, T),
        'stats_properties': next_props,
    }

    if keep_all_timesteps is not True:
        return next_carry, {}
    if keep_intermediary_data is True:
        return next_carry, {'cells': cells, 'field': field, 'potential': potential, 'stats': stats}
    return next_carry, {'stats': stats}
@functools.partial(jit, static_argnums=(2, 3, 4))
def _scan_fn_without_stat(
    state_carry: jnp.ndarray,
    x_rng_key: jnp.ndarray,
    update_fn: Callable,
    keep_intermediary_data: bool = False,
    keep_all_timesteps: bool = True,
) -> Tuple[jnp.ndarray, Dict]:
    """Scan step without statistics: only apply ``update_fn`` to the state.

    Args:
        state_carry: State fed to ``update_fn`` for this step.
        x_rng_key: PRNG key of this step, forwarded to ``update_fn``.
        update_fn: Returns the new state, the field and the potential.
        keep_intermediary_data: Also emit the field and the potential.
        keep_all_timesteps: When not ``True`` nothing is emitted for the step.

    Returns:
        The new state and the per-step output dictionary. The emitted
        ``'cells'`` entry is the state received by the step, not the new one.
    """
    next_state, field, potential = update_fn(x_rng_key, state_carry)

    if keep_all_timesteps is not True:
        return next_state, {}

    outputs = {'cells': state_carry}
    if keep_intermediary_data is True:
        outputs['field'] = field
        outputs['potential'] = potential
    return next_state, outputs
###
# Differentiable functions
###
def make_pipeline_fn(
    max_run_iter: int,
    dt: float,
    apply_fn: Callable,
    loss_fn: Callable[[jnp.ndarray, Tuple, Tuple], Tuple[jnp.ndarray, ...]],
    keep_intermediary_data: bool = False,
    keep_all_timesteps: bool = True,
) -> Callable:
    """Make a jitted function which runs a full simulation and computes a loss

    The returned function can be differentiated, for example::

        pipeline_fn = make_pipeline_fn(max_run_iter, dt, apply_fn, loss_fn)
        grads_fn = jax.value_and_grad(pipeline_fn, argnums=1, has_aux=True)

    Args:
        max_run_iter: Number of simulation iterations
        dt: Update rate, forwarded to ``apply_fn`` as the ``dt`` keyword
        apply_fn: Called as ``apply_fn({'params': params, **variables}, rng_key, state, dt=dt)``;
            must return the new state, the field and the potential
        loss_fn: Called as ``loss_fn(rng_key, preds, targets)``; its first output
            is the loss, the remaining outputs are returned as auxiliary data
        keep_intermediary_data: Set to ``True`` to also give the fields and potentials to ``loss_fn``
        keep_all_timesteps: Set to ``False`` to give only the final state to ``loss_fn``

    Returns:
        The jitted pipeline function.
    """
    @jax.jit
    def fn(rng_key, params, variables, state0, targets):
        """Simulate a single configuration and compute its loss

        It uses jax.lax.scan function under the hood.

        Args:
            rng_key: JAX PRNG key.
            params: Parameters, merged under the ``'params'`` key of the variables given to ``apply_fn``
            variables: Other entries of the variables dictionary given to ``apply_fn``
            state0: Initial state of the simulation
            targets: Target values forwarded untouched to ``loss_fn``

        Returns:
            A 2-tuple ``(loss, rest)`` where ``rest`` is the list of the
            remaining outputs of ``loss_fn``.
        """
        update_fn = functools.partial(apply_fn, {'params': params, **variables}, dt=dt)
        step_fn: Callable = functools.partial(
            _scan_fn_without_stat,
            update_fn=update_fn,
            keep_intermediary_data=keep_intermediary_data,
            keep_all_timesteps=keep_all_timesteps,
        )

        # The number of iterations is controlled by the number of PRNG keys.
        # The first key is kept for `loss_fn`; the others are sliced in one go
        # rather than unpacked into a Python list and stacked again.
        keys = jax.random.split(rng_key, max_run_iter + 1)
        rng_key, subkeys = keys[0], keys[1:]
        final_state_carry, preds = lax.scan(step_fn, state0, subkeys)

        if keep_all_timesteps is False:
            # Nothing was emitted by the scan steps: only the final state is available
            preds = {'cells': final_state_carry}

        loss, *rest = loss_fn(rng_key, preds, targets)

        return loss, rest

    return fn
def make_gradient_fn(pipeline_fn: Callable, normalize: bool = True) -> Callable:
    """Make the gradient function

    Args:
        pipeline_fn: Leniax pipeline function which runs a full simulation and
            returns ``(loss, aux)``
        normalize: Set to ``True`` to normalize each gradient leaf independently

    Returns:
        A jitted function returning ``((loss, aux), grads)`` where the
        gradients are taken with respect to ``params``.
    """
    # `params` is the second positional argument of the pipeline function
    value_and_grads = jax.value_and_grad(pipeline_fn, argnums=1, has_aux=True)

    @jax.jit
    def fn(rng_key: jnp.ndarray, params: Dict, variables: Dict, state0: jnp.ndarray, targets: Dict):
        outputs, grads = value_and_grads(rng_key, params, variables, state0, targets)
        loss, aux = outputs

        if normalize is True:
            grads = tree_map(normalize_fn, grads)

        return (loss, aux), grads

    return fn
def normalize_fn(g):
    """Divide ``g`` by its norm; a small constant keeps the division defined for a zero norm."""
    norm = jnp.linalg.norm(g)
    return g / (norm + 1e-8)