Source code for leniax.statistics

import functools
import jax
import jax.numpy as jnp
from jax import jit
from typing import List, Dict, Callable, Tuple

from .constant import EPSILON
from .utils import center_world


def build_compute_stats_fn(world_params: Dict, render_params: Dict) -> Callable:
    """Construct the ``compute_stats`` function.

    Coordinate grids centered on the middle of the world are pre-computed once
    here and captured by the returned (jitted) function.

    Args:
        world_params: World parameters dictionary. Must contain ``R`` (used to
            normalize volumes, masses and distances) and ``T`` (the time step is
            ``dt = 1 / T``).
        render_params: Render parameters dictionary. Must contain
            ``world_size``, the spatial size of the world.

    Returns:
        The jitted ``compute_stats`` function.
    """
    world_size = render_params['world_size']
    R = world_params['R']
    dt = 1. / world_params['T']

    # Negative axes address the trailing spatial dims (resp. channel + spatial
    # dims) independently of how many leading dims the arrays have.
    world_dims_axes = tuple(range(-len(world_size), 0, 1))
    non_batch_dims_axes = tuple(range(-(1 + len(world_size)), 0, 1))

    midpoint = jnp.expand_dims(
        jnp.asarray([size // 2 for size in world_size]), axis=world_dims_axes
    )  # [nb_world_dims, 1, 1, ...]
    coords = jnp.indices(world_size)  # [nb_world_dims, world_dims...]
    # Insert two unit axes so the coordinates broadcast against arrays of
    # shape [N, C, world_dims...].
    centered_coords = jnp.expand_dims(coords - midpoint, axis=(1, 2))  # [nb_world_dims, 1, 1, world_dims...]

    @jit
    def compute_stats(
        cells: jnp.ndarray,
        field: jnp.ndarray,
        potential: jnp.ndarray,
        previous_total_shift_idx: jnp.ndarray,
        previous_mass_centroid: jnp.ndarray,
        previous_mass_angle: jnp.ndarray,
    ) -> Tuple[Dict, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Compute statistics of a Leniax simulation.

        The mass angle is computed from the first two world dimensions only,
        so at least two world dimensions are required.

        Args:
            cells: state of shape ``[N, C, world_dims...]``
            field: state of shape ``[N, C, world_dims...]``
            potential: state of shape ``[N, C, world_dims...]``
            previous_total_shift_idx: accumulated integer world shift of shape
                ``[N, nb_world_dims]``
            previous_mass_centroid: residual (non-integer part) of the previous
                mass centroid, of shape ``[nb_world_dims, N]``
            previous_mass_angle: previous mass displacement angle in degrees,
                of shape ``[N]``

        Returns:
            A tuple ``(stats, total_shift_idx, mass_centroid, mass_angle)``:
            the statistics dictionary (values of shape ``[N]``, except
            ``channel_mass`` of shape ``[N, C]``), the updated total shift
            ``[N, nb_world_dims]``, the centroid residual
            ``[nb_world_dims, N]`` and the mass angle ``[N]``. The last three
            are meant to be fed back as the ``previous_*`` arguments.
        """
        # https://en.wikipedia.org/wiki/Image_moment
        # To avoid weird behaviours when species are crossing the frontiers,
        # stats are computed on a world re-centered with the accumulated shift.
        centered_cells, centered_field, _ = center_world(
            cells, field, potential, previous_total_shift_idx, world_dims_axes
        )
        positive_field = jnp.maximum(centered_field, 0)

        # Zeroth-order moments
        m_00 = centered_cells.sum(axis=non_batch_dims_axes)  # [N]
        g_00 = positive_field.sum(axis=non_batch_dims_axes)  # [N]

        potential_volume = (potential > EPSILON).sum(axis=non_batch_dims_axes) / R**2  # [N]

        channel_mass = centered_cells.sum(axis=world_dims_axes) / R**2  # [N, C]
        mass = m_00 / R**2
        mass_volume = (centered_cells > EPSILON).sum(axis=non_batch_dims_axes) / R**2  # [N]
        mass_density = mass / (mass_volume + EPSILON)

        growth = g_00 / R**2
        growth_volume = (positive_field > EPSILON).sum(axis=non_batch_dims_axes) / R**2  # [N]
        growth_density = growth / (growth_volume + EPSILON)

        # First-order moments -> mass centroid and its motion
        AX = centered_cells * centered_coords  # [nb_world_dims, N, C, world_dims...]
        MX = AX.sum(axis=non_batch_dims_axes)  # [nb_world_dims, N]
        mass_centroid = MX / (m_00 + EPSILON)
        delta_mass = mass_centroid - previous_mass_centroid  # [nb_world_dims, N]
        dist_m = jnp.linalg.norm(delta_mass, axis=0)  # [N]
        mass_speed = dist_m / R / dt  # [N]
        # The angle is forced to 0 when the displacement is below 0.001 * R.
        mass_angle = jnp.degrees(jnp.arctan2(delta_mass[1], delta_mass[0])) * (dist_m / R > 0.001)
        # Wrap the angle difference into [-180, 180) before dividing by dt.
        mass_angle_speed = ((mass_angle - previous_mass_angle + 540) % 360 - 180) / dt  # [N]

        GX = (positive_field * centered_coords).sum(axis=non_batch_dims_axes)  # [nb_world_dims, N]
        growth_centroid = GX / (g_00 + EPSILON)
        mass_growth_dist = jnp.linalg.norm(growth_centroid - mass_centroid, axis=0) / R

        # Second-order central moments mu_2 = M_2 - centroid * M_1,
        # normalized by m_00**2.
        MX2 = (AX * centered_coords).sum(axis=non_batch_dims_axes)  # [nb_world_dims, N]
        MuX2 = MX2 - mass_centroid * MX
        mass2_centroid = MuX2 / (m_00**2 + EPSILON)
        inertia = mass2_centroid.sum(axis=0)

        stats = {
            'channel_mass': channel_mass,
            'mass': mass,
            'mass_volume': mass_volume,
            'mass_density': mass_density,
            'growth': growth,
            'growth_volume': growth_volume,
            'growth_density': growth_density,
            'mass_speed': mass_speed,
            'mass_angle_speed': mass_angle_speed,
            'mass_growth_dist': mass_growth_dist,
            'inertia': inertia,
            'potential_volume': potential_volume,
        }

        # Fold the integer part of the centroid into the accumulated shift,
        # wrapping around the world size.
        shift_idx = mass_centroid.astype(jnp.int32).T
        world_shape = jnp.array(cells.shape[2:], dtype=jnp.int32)
        total_shift_idx = (previous_total_shift_idx + shift_idx) % world_shape

        # Since we will center the world before computing stats, the mass
        # centroid will also be shifted: keep only the part not absorbed by
        # the integer shift to avoid shifting errors.
        mass_centroid = mass_centroid - mass_centroid.astype(jnp.int32)

        return (stats, total_shift_idx, mass_centroid, mass_angle)

    return compute_stats
###
# Heuristics
###
@jit
def check_heuristics(stats: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Check heuristics on statistic data.

    Scans over the leading (timestep) axis of ``stats``. For each simulation it
    keeps a running flag that drops to 0 as soon as one heuristic fails and
    then stays at 0 for all later timesteps.

    Args:
        stats: Simulation statistics dictionary. Every value is stacked along a
            leading timestep axis ``T``, e.g. ``stats['mass']`` has shape
            ``[T, N]`` and ``stats['channel_mass']`` has shape ``[T, N, C]``.

    Returns:
        A float array of shape ``[T, N]`` holding 1. while the heuristics are
        valid for a simulation at a timestep and 0. afterwards.
    """

    def fn(carry: Dict, stat_t: Dict[str, jnp.ndarray]):
        """Single scan step: update the carry and emit the running flags."""
        should_continue = carry['should_continue']
        init_channel_mass = carry['init_channel_mass']
        init_mass = carry['init_mass']
        previous_mass = carry['previous_mass']
        previous_sign = carry['previous_sign']
        # Copy so the incoming carry is never mutated in place.
        counters = dict(carry['counters'])

        mass = stat_t['mass']
        channel_mass = stat_t['channel_mass']
        mass_volume = stat_t['mass_volume']

        # Only the per-channel mass bounds are enforced here; the total-mass
        # checks (min_mass_heuristic / max_mass_heuristic) are not applied.
        should_continue_cond = min_channel_mass_heuristic(EPSILON, channel_mass)
        should_continue_cond = jnp.logical_and(
            should_continue_cond, max_channel_mass_heuristic(init_channel_mass, channel_mass)
        )

        sign = jnp.sign(mass - previous_mass)
        cond, counters['nb_monotone_step'] = monotonic_heuristic(
            sign, previous_sign, counters['nb_monotone_step']
        )
        should_continue_cond = jnp.logical_and(should_continue_cond, cond)

        cond, counters['nb_max_volume_step'] = mass_volume_heuristic(
            mass_volume, counters['nb_max_volume_step']
        )
        should_continue_cond = jnp.logical_and(should_continue_cond, cond)

        # Sticky flag: once 0 for a simulation it stays 0.
        should_continue = should_continue * should_continue_cond

        new_carry = {
            'should_continue': should_continue,
            'init_channel_mass': init_channel_mass,
            'init_mass': init_mass,
            'previous_mass': mass,
            'previous_sign': sign,
            'counters': counters,
        }
        return new_carry, should_continue

    N = stats['mass'].shape[1]
    init_carry = {
        'should_continue': jnp.ones(N),
        'init_channel_mass': stats['channel_mass'][0],
        'init_mass': stats['mass'][0],
        'previous_mass': stats['mass'][0],
        'previous_sign': jnp.zeros(N),
        'counters': init_counters(N),
    }
    # jax.lax.scan traces `fn`; to debug step by step, call `fn` in a plain
    # Python loop over the timesteps of `stats` instead.
    _, continue_stat = jax.lax.scan(fn, init_carry, stats, unroll=1)

    return continue_stat
def init_counters(N: int) -> Dict[str, jnp.ndarray]:
    """Initialize the counters used by the heuristics.

    Args:
        N: Number of simulations checked in parallel; each counter holds one
            entry per simulation (it is not a number of timesteps).

    Returns:
        A dictionary mapping counter names to zero-initialized ``int32``
        arrays of shape ``[N]``. ``nb_slow_mass_step`` is not read by any
        heuristic visible in this module.
    """
    return {
        'nb_monotone_step': jnp.zeros(N, dtype=jnp.int32),
        'nb_slow_mass_step': jnp.zeros(N, dtype=jnp.int32),
        'nb_max_volume_step': jnp.zeros(N, dtype=jnp.int32),
    }
def min_channel_mass_heuristic(epsilon: float, channel_mass: jnp.ndarray) -> jnp.ndarray:
    """Check that every channel still carries at least ``epsilon`` mass.

    Args:
        epsilon: Minimal mass a channel must keep (used as a threshold).
        channel_mass: Total mass per channel of shape ``[N, C]``

    Returns:
        A boolean array of shape ``[N]``, True where no channel fell below
        ``epsilon``.
    """
    channel_is_alive = channel_mass >= epsilon  # [N, C]
    return jnp.all(channel_is_alive, axis=1)
def max_channel_mass_heuristic(
    init_channel_mass: jnp.ndarray,
    channel_mass: jnp.ndarray,
    max_ratio: float = 3.,
) -> jnp.ndarray:
    """Check that no channel has grown beyond ``max_ratio`` times its initial mass.

    Args:
        init_channel_mass: Initial mass per channel of shape ``[N, C]``
        channel_mass: Current total mass per channel of shape ``[N, C]``
        max_ratio: Maximal allowed growth factor per channel (defaults to 3).

    Returns:
        A boolean array of shape ``[N]``, True where every channel satisfies
        ``channel_mass <= max_ratio * init_channel_mass``.
    """
    should_continue_cond = (channel_mass <= max_ratio * init_channel_mass).all(axis=1)
    return should_continue_cond
def min_mass_heuristic(epsilon: float, mass: jnp.ndarray) -> jnp.ndarray:
    """Check that the total mass has not vanished.

    Args:
        epsilon: Minimal total mass to keep (used as a threshold).
        mass: Total mass of shape ``[N]``

    Returns:
        A boolean array of shape ``[N]``, True where ``mass >= epsilon``.
    """
    is_alive = epsilon <= mass
    return is_alive
def max_mass_heuristic(init_mass: jnp.ndarray, mass: jnp.ndarray, max_ratio: float = 3.) -> jnp.ndarray:
    """Check that the total mass has not grown beyond ``max_ratio`` times its initial value.

    Args:
        init_mass: Initial total mass of shape ``[N]``
        mass: Current total mass of shape ``[N]``
        max_ratio: Maximal allowed growth factor (defaults to 3).

    Returns:
        A boolean array of shape ``[N]``, True where
        ``mass <= max_ratio * init_mass``.
    """
    should_continue_cond = mass <= max_ratio * init_mass
    return should_continue_cond
# Maximal number of consecutive timesteps the sign of the mass variation may
# stay unchanged before the simulation is flagged.
MONOTONIC_STOP_STEP = 128


def monotonic_heuristic(
    sign: jnp.ndarray,
    previous_sign: jnp.ndarray,
    monotone_counter: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Detect a mass variation that keeps the same sign for too long.

    A repeated zero sign (mass unchanged) also counts as "same sign".

    Args:
        sign: Current sign of mass variation of shape ``[N]``
        previous_sign: Previous sign of mass variation of shape ``[N]``
        monotone_counter: Number of consecutive timesteps with an unchanged
            sign, of shape ``[N]``

    Returns:
        A tuple ``(should_continue_cond, monotone_counter)``: a boolean array
        of shape ``[N]`` and the updated counter.
    """
    unchanged = sign == previous_sign
    # Increment while the sign is unchanged, restart at 1 otherwise.
    next_counter = jnp.where(unchanged, monotone_counter + 1, 1)
    return next_counter <= MONOTONIC_STOP_STEP, next_counter
# Threshold on ``mass_volume``, i.e. a count of activated cells divided by
# R**2. It dates from a 128*128 world with a kernel radius of 13:
# 10% of 128*128 ~= 1600 cells, and 1600 / 13**2 ~= 9.5, rounded to 10.
# NOTE(review): an earlier comment described this as "3200 pixels activated",
# which would match R ~= 18 rather than 13; confirm the intended reference.
MASS_VOLUME_THRESHOLD = 10.
# Maximal number of consecutive timesteps the volume may stay above the threshold.
MASS_VOLUME_STOP_STEP = 128


def mass_volume_heuristic(mass_volume: jnp.ndarray,
                          mass_volume_counter: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Detect a mass volume that stays above the threshold for too long.

    Args:
        mass_volume: Mass volume of shape ``[N]``
        mass_volume_counter: Number of consecutive timesteps with a volume
            above the threshold, of shape ``[N]``

    Returns:
        A tuple ``(should_continue_cond, mass_volume_counter)``: a boolean
        array of shape ``[N]`` and the updated counter.
    """
    too_large = jnp.asarray(mass_volume > MASS_VOLUME_THRESHOLD)
    # Increment while the volume stays too large, restart at 1 otherwise.
    next_counter = jnp.where(too_large, mass_volume_counter + 1, 1)
    return next_counter <= MASS_VOLUME_STOP_STEP, next_counter
###
# Utils
###
def stats_list_to_dict(all_stats: List[Dict]) -> Dict[str, jnp.ndarray]:
    """Turn a list of per-timestep statistics dictionaries into a dictionary of arrays.

    The keys are taken from the first dictionary; each value becomes an array
    stacking that statistic over the list, in order.

    Args:
        all_stats: List of 1-timestep statistics dictionaries.

    Returns:
        A dictionary of arrays with one entry per timestep along the leading
        axis, or an empty dictionary when ``all_stats`` is empty.
    """
    if not all_stats:
        return {}

    keys = list(all_stats[0].keys())
    columns: Dict[str, List] = {key: [] for key in keys}
    for stat in all_stats:
        for key, value in stat.items():
            columns[key].append(value)

    return {key: jnp.array(columns[key]) for key in keys}