Source code for leniax.core

"""Leniax core simulation functions
"""

import functools
import jax
from jax import lax, jit
import jax.numpy as jnp
from typing import Callable, Tuple, Optional

# Signature shared by every state-update rule:
#   (rng_key, state, field, dt) -> new_state
# The PRNG key is annotated as a plain array: ``jax.random.KeyArray`` was
# deprecated in JAX 0.4.16 and removed in 0.4.24, so evaluating it here would
# raise ``AttributeError`` at import time on current JAX releases.
GetStateCallableType = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]


@functools.partial(jit, static_argnums=(6, 7, 8))
def update(
    rng_key: jnp.ndarray,
    state: jnp.ndarray,
    K: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    dt: jnp.ndarray,
    get_potential_fn: Callable,
    get_field_fn: Callable,
    get_state_fn: GetStateCallableType,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Update the cells state

    Jitted function with static argnums. Use functools.partial to set the
    different function. Avoid changing non-static argument shape for performance.

    The three callables are chained: potential -> field -> new state.

    Args:
        rng_key: JAX PRNG key, forwarded untouched to ``get_state_fn``.
        state: cells state ``[N, nb_channels, world_dims...]``
        K: Kernel ``[K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]``
        gf_params: Growth function parameters ``[nb_kernels, params_shape...]``
        kernels_weight_per_channel: Kernels weight used in the averaging function ``[nb_channels, nb_kernels]``
        dt: Update rate ``[N]``
        get_potential_fn: **(jit static arg)** Function used to compute the potential,
            called as ``get_potential_fn(state, K)``
        get_field_fn: **(jit static arg)** Function used to compute the field,
            called as ``get_field_fn(potential, gf_params, kernels_weight_per_channel)``
        get_state_fn: **(jit static arg)** Function used to compute the new cell state,
            called as ``get_state_fn(rng_key, state, field, dt)``

    Returns:
        A tuple ``(new_state, field, potential)``: the updated cells state,
        the field used to compute it and the potential used to compute the field.
        No PRNG key is returned.
    """
    potential = get_potential_fn(state, K)
    field = get_field_fn(potential, gf_params, kernels_weight_per_channel)
    new_state = get_state_fn(rng_key, state, field, dt)

    return new_state, field, potential
@functools.partial(jit, static_argnums=(2, 3))
def get_potential_fft(
    state: jnp.ndarray,
    K: jnp.ndarray,
    tc_indices: Optional[Tuple] = None,
    channel_first: bool = True,
) -> jnp.ndarray:
    """Compute the potential using FFT

    The first dimension of cells and K is the vmap dimension.

    The FFT of ``state`` is multiplied element-wise with ``K``, so ``K`` is
    used directly as the frequency-domain multiplier.

    Args:
        state: cells state ``[N, nb_channels, world_dims...]``
            (``[N, world_dims..., nb_channels]`` when ``channel_first`` is False)
        K: Kernels ``[1, nb_channels, max_k_per_channel, world_dims...]``
            (when ``channel_first`` is False, ``max_k_per_channel`` is read from the
            last axis, i.e. ``[1, world_dims..., nb_channels, max_k_per_channel]``)
        tc_indices: **(jit static arg)** Optional ``1-dim`` tuple of channels indices to keep
        channel_first: **(jit static arg)** see
            ``https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html``

    Returns:
        An array containing the potential
    """
    ndim = state.ndim

    if channel_first is True:
        channel_dim = 1
        world_axes = tuple(range(-1, -ndim + 1, -1))
        # [N, nb_channels, 1, world_dims...]
        fft_cells = jnp.fft.fftn(state, axes=world_axes)[:, :, jnp.newaxis]
        # The new axis is inserted *before* the world axes, so negative
        # indices still designate the world axes.
        inverse_axes = world_axes
    else:
        channel_dim = ndim - 1
        world_axes = tuple(range(-2, -ndim, -1))
        # [N, world_dims..., nb_channels, 1]
        fft_cells = jnp.fft.fftn(state, axes=world_axes)[..., jnp.newaxis]
        # The new axis is appended *after* the world axes: every negative
        # index must move one step left to keep pointing at the world axes.
        inverse_axes = tuple(axis - 1 for axis in world_axes)

    nb_channels = state.shape[channel_dim]

    fft_out = fft_cells * K
    conv_out = jnp.real(jnp.fft.ifftn(fft_out, axes=inverse_axes))

    # Merge the (channel, kernel) pair of axes into a single axis.
    if channel_first is True:
        max_k_per_channel = K.shape[2]
        conv_shape = (-1, max_k_per_channel * nb_channels) + state.shape[2:]
    else:
        max_k_per_channel = K.shape[-1]
        conv_shape = state.shape[:-1] + (max_k_per_channel * nb_channels, )
    conv_out_reshaped = conv_out.reshape(conv_shape)

    if tc_indices is None:
        return conv_out_reshaped

    return jnp.take(conv_out_reshaped, jnp.array(tc_indices), axis=channel_dim)
@functools.partial(jit, static_argnums=(2, 3, 4))
def get_potential(
    state: jnp.ndarray,
    K: jnp.ndarray,
    padding: Tuple,
    tc_indices: Optional[Tuple] = None,
    channel_first: bool = True,
) -> jnp.ndarray:
    """Compute the potential using lax.conv_general_dilated

    The first dimension of cells and K is the vmap dimension.

    The state is padded with wrap-around values and then convolved with
    ``'VALID'`` padding, using one feature group per channel.

    Args:
        state: cells state ``[N, nb_channels, world_dims...]``
        K: Kernels ``[K_o=nb_channels * max_k_per_channel, K_i=1, kernel_dims...]``
        padding: **(jit static arg)** padding informations, forwarded as-is to
            ``jnp.pad(state, padding, mode='wrap')``
        tc_indices: **(jit static arg)** Optional ``1-dim`` tuple of channels indices to keep
        channel_first: **(jit static arg)** see
            ``https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html``

    Returns:
        An array containing the potential
    """
    dim_numbers = _conv_dimension_numbers(state.shape, channel_first)
    channel_dim = dim_numbers.lhs_spec[1]
    nb_channels = state.shape[channel_dim]

    # One unit stride per world dimension (everything but batch and channel),
    # so the convolution is not restricted to 2-dim worlds.
    nb_world_dims = state.ndim - 2
    window_strides = (1, ) * nb_world_dims

    padded_state = jnp.pad(state, padding, mode='wrap')
    conv_out = lax.conv_general_dilated(
        padded_state,
        K,
        window_strides,
        'VALID',
        feature_group_count=nb_channels,
        dimension_numbers=dim_numbers,
    )

    if tc_indices is None:
        return conv_out

    return jnp.take(conv_out, jnp.array(tc_indices), axis=channel_dim)
def _conv_dimension_numbers(input_shape, channel_first=True): """Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) if channel_first is True: lhs_spec = tuple(range(0, ndim)) rhs_spec = tuple(range(0, ndim)) else: lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
@functools.partial(jit, static_argnums=(3, 4))
def get_field(
    potential: jnp.ndarray,
    gf_params: jnp.ndarray,
    kernels_weight_per_channel: jnp.ndarray,
    growth_fn_t: Tuple[Callable, ...],
    weighted_fn: Callable
) -> jnp.ndarray:
    """Compute the field

    Jitted function with static argnums. Use functools.partial to set the
    different function. Avoid changing non-static argument shape for performance.

    Each growth function is applied to the slice of the potential found at its
    own index along axis 1; the results are stacked back along axis 1 and then
    merged by ``weighted_fn``.

    Args:
        potential: ``[N, nb_kernels, world_dims...]``
        gf_params: ``[nb_kernels, nb_gf_params]``
        kernels_weight_per_channel: Kernels weight used in the averaging function ``[nb_channels, nb_kernels]``
        growth_fn_t: **(jit static arg)** Tuple of growth functions, each called as
            ``growth_fn(gf_params[i], potential[:, i])``. ``length: nb_kernels``
        weighted_fn: **(jit static arg)** Function used to merge fields linked to the same channel,
            called as ``weighted_fn(fields, kernels_weight_per_channel)``

    Returns:
        An array containing the field
    """
    sub_fields = [
        growth_fn(gf_params[kernel_idx], potential[:, kernel_idx])
        for kernel_idx, growth_fn in enumerate(growth_fn_t)
    ]

    # [N, nb_kernels, world_dims...]
    stacked_fields = jnp.stack(sub_fields, axis=1)

    return weighted_fn(stacked_fields, kernels_weight_per_channel)
def weighted_sum(fields: jnp.ndarray, weights: jnp.ndarray) -> jnp.ndarray:
    """Compute the weighted sum of sub fields

    The kernel axis is contracted against ``weights`` with a single matrix
    product, after flattening the batch and world axes together.

    Args:
        fields: Raw sub fields ``[N, nb_kernels, world_dims...]``
        weights: Weights used to compute the sum ``[nb_channels, nb_kernels]``
            0. values are used to indicate that a given channels does not receive inputs from this kernel

    Returns:
        The unnormalized field ``[N, nb_channels, world_dims...]``
    """
    batch_size, nb_kernels, *world_shape = fields.shape
    nb_channels = weights.shape[0]

    # [nb_kernels, N * prod(world_dims)]
    flat_fields = jnp.swapaxes(fields, 0, 1).reshape(nb_kernels, -1)

    # [nb_channels, N * prod(world_dims)]
    mixed = jnp.matmul(weights, flat_fields)

    # [nb_channels, N, world_dims...] -> [N, nb_channels, world_dims...]
    mixed = mixed.reshape([nb_channels, batch_size] + world_shape)

    return jnp.swapaxes(mixed, 0, 1)
def weighted_mean(fields: jnp.ndarray, weights: jnp.ndarray) -> jnp.ndarray:
    """Compute the weighted mean of sub fields

    The weighted sum of the sub fields is divided, channel by channel, by the
    sum of the weights of that channel. A channel whose weights sum to 0. is
    divided by zero.

    Args:
        fields: Raw sub fields ``[N, nb_kernels, world_dims...]``
        weights: Weights used to compute the sum ``[nb_channels, nb_kernels]``
            0. values are used to indicate that a given channels does not receive inputs from this kernel

    Returns:
        The normalized field ``[N, nb_channels, world_dims...]``
    """
    fields_out = weighted_sum(fields, weights)

    # Shape the per-channel normaliser as [1, nb_channels, 1, ...] with one
    # trailing singleton per world dimension, so that it broadcasts against the
    # channel axis whatever the number of world dimensions (not only 2).
    nb_world_dims = fields.ndim - 2
    weights_total = weights.sum(axis=1).reshape((1, -1) + (1, ) * nb_world_dims)

    return fields_out / weights_total
def get_state(
    rng_key: jnp.ndarray,
    state: jnp.ndarray,
    field: jnp.ndarray,
    dt: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the new cells state using the original Lenia formula

    The forward value is ``clip(state + dt * field, 0, 1)``; the gradient is the
    one of the unclipped ``state + dt * field`` (straight-through estimator).

    Args:
        rng_key: JAX PRNG key. Not used by this rule, kept for the shared signature.
        state: Current cells state ``[N, nb_channels, world_dims...]``
        field: Current field``[N, nb_channels, world_dims...]``
        dt: Update rate ``[N]``, combined with ``field`` as ``dt * field``

    Returns:
        The new cells state

    Reference:
        https://arxiv.org/abs/1812.05433
    """
    unclipped = state + dt * field

    # Straight-through estimator: ``passthrough`` is numerically zero but
    # carries the gradient of ``unclipped``, while the clipped value carries
    # the forward result and no gradient.
    clipped = jnp.clip(unclipped, 0., 1.)
    passthrough = unclipped - lax.stop_gradient(unclipped)

    return passthrough + lax.stop_gradient(clipped)


def get_state_v2(
    rng_key: jnp.ndarray,
    state: jnp.ndarray,
    field: jnp.ndarray,
    dt: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the new cells state using the asymptotic Lenia formula

    The new state is the interpolation ``state * (1 - dt) + dt * field``;
    no clipping is applied.

    Args:
        rng_key: JAX PRNG key. Not used by this rule, kept for the shared signature.
        state: Current cells state ``[N, nb_channels, world_dims...]``
        field: Current field``[N, nb_channels, world_dims...]``
        dt: Update rate ``[N]``

    Returns:
        The new cells state

    Reference:
        https://direct.mit.edu/isal/proceedings/isal/91/102916
    """
    return state * (1 - dt) + dt * field


def get_state_simple(
    rng_key: jnp.ndarray,
    state: jnp.ndarray,
    field: jnp.ndarray,
    dt: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the new cells state by simply adding the scaled field

    The new state is ``state + dt * field``, without any clipping.

    Args:
        rng_key: JAX PRNG key. Not used by this rule, kept for the shared signature.
        state: Current cells state ``[N, nb_channels, world_dims...]``
        field: Current field``[N, nb_channels, world_dims...]``
        dt: Update rate ``[N]``, used to scale the field

    Returns:
        The new cells state
    """
    return state + dt * field


# State-update rules, indexed by their version name.
register = {
    'v1': get_state,
    'v2': get_state_v2,
    'simple': get_state_simple,
}