"""Leniax helper functions
Those functions are provided to ease the use of this library. See them as template
gluing core functions together to achieve common usages.
"""
import os
import copy
import logging
import time
import json
import pickle
import functools
import numpy as np
import scipy
import jax
import jax.numpy as jnp
from typing import Callable, Dict, Tuple, List, Union, Optional
import matplotlib.pyplot as plt
from . import core as leniax_core
from . import runner as leniax_runner
from . import initializations as leniax_init
from . import statistics as leniax_stat
from . import loader as leniax_loader
from . import utils as leniax_utils
from . import video as leniax_video
from . import colormaps as leniax_colormaps
from . import kernels as leniax_kernels
from .kernel_functions import register as kf_register
from .growth_functions import register as gf_register
cdir = os.path.dirname(os.path.realpath(__file__))
[docs]def init(
config: Dict,
use_init_cells: bool = True,
fft: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray, leniax_kernels.KernelMapping]:
"""Construct the initial state and metadata to run a simulation.
Args:
config: Leniax confguration
use_init_cells: Set to ``True`` to use the ``init_cells`` configuration property.
fft: Set to ``True`` to use FFT optimization
Returns:
A 3-tuple representing the initial state, Kernel and mapping data
"""
nb_dims = config['world_params']['nb_dims']
nb_channels = config['world_params']['nb_channels']
world_size = config['render_params']['world_size']
kernels_params = config['kernels_params']
assert len(world_size) == nb_dims
assert nb_channels > 0
raw_cells = leniax_loader.load_raw_cells(config, use_init_cells)
scale = config['world_params']['scale']
if scale != 1.:
raw_cells = jnp.array(
[scipy.ndimage.zoom(raw_cells[i], scale, order=0) for i in range(nb_channels)],
dtype=jnp.float32,
)
# We update the configuration here, the new R value will be used in the statistics
config['world_params']['R'] *= scale
# assert cells.shape[1] * 2.2 < config['render_params']['world_size'][0]
# assert cells.shape[2] * 2.2 < config['render_params']['world_size'][1]
# on = round(max(cells.shape[1] * 1.25, cells.shape[2] * 1.25) / 2.)
# init_cells = create_init_cells(world_size, nb_channels, [
# jnp.rot90(cells, k=2, axes=(1, 2)),
# jnp.rot90(cells, k=1, axes=(1, 2)),
# jnp.rot90(cells, k=0, axes=(1, 2)),
# jnp.rot90(cells, k=-1, axes=(1, 2)),
# ], [
# [0, -on, -on],
# [0, -on, on],
# [0, on, on],
# [0, on, -on],
# ])
if len(raw_cells.shape) > 1 + nb_dims:
init_cells = create_init_cells(world_size, nb_channels, raw_cells)
else:
init_cells = create_init_cells(world_size, nb_channels, [raw_cells])
K, mapping = leniax_kernels.get_kernels_and_mapping(kernels_params, world_size, nb_channels, config['world_params']['R'], fft)
return init_cells, K, mapping
[docs]def create_init_cells(
world_size: List[int],
nb_channels: int,
other_cells: Union[jnp.ndarray, List[jnp.ndarray]] = [],
offsets: List[List[int]] = [],
) -> jnp.ndarray:
"""Construct the initial state
Args:
world_size: World size
nb_channels: Number of world channels
other_cells: Other initial states to merge
offsets: Offsets used to merge other initial states
Returns:
The initial state
"""
world_shape = [nb_channels] + world_size # [C, world_dims...]
cells = jnp.zeros(world_shape)
if isinstance(other_cells, list):
if len(offsets) == len(other_cells):
for c, offset in zip(other_cells, offsets):
if len(c) != 0:
cells = leniax_utils.merge_cells(cells, c, offset)
else:
for c in other_cells:
if len(c) != 0:
cells = leniax_utils.merge_cells(cells, c)
cells = cells[jnp.newaxis, ...]
elif isinstance(other_cells, jnp.ndarray) or isinstance(other_cells, np.ndarray):
if isinstance(other_cells, np.ndarray):
other_cells = jnp.array(other_cells, dtype=jnp.float32)
cells = other_cells
else:
raise ValueError(f'Don\'t know how to handle {type(other_cells)}')
return cells
[docs]def init_and_run(
rng_key: jax.random.KeyArray,
config: Dict,
use_init_cells: bool = True,
with_jit: bool = True,
fft: bool = True,
stat_trunc: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
"""Initialize and simulate a Lenia configuration
To simulate a configuration with multiple initializations you must set:
- ``with_jit=True`` so the function use the scan implementaton.
- ``stat_trunc=False`` multiple initializations means different simulation length measured by the statistics.
Args:
rng_key: JAX PRNG key.
config: Lenia configuration
use_init_cells: Set to ``True`` to use the ``init_cells`` configuration property.
with_jit: Set to ``True`` to use the jitted scan implementation
fft: Set to ``True`` to use FFT optimization
stat_trunc: Set to ``True`` to truncate run based on its statistics
Returns:
A tuple of `[nb_iter, nb_init, nb_channels, world_dims...]` shaped
cells, fields, potentials and statistics of the simulation.
"""
config = copy.deepcopy(config)
cells, K, mapping = init(config, use_init_cells, fft)
gf_params = mapping.get_gf_params()
kernels_weight_per_channel = mapping.get_kernels_weight_per_channel()
world_params = config['world_params']
get_state_fn_slug = world_params['get_state_fn_slug'] if 'get_state_fn_slug' in world_params else 'v1'
weighted_average = world_params['weighted_average'] if 'weighted_average' in world_params else True
R = config['world_params']['R']
T = jnp.array(config['world_params']['T'], dtype=jnp.float32)
max_run_iter = config['run_params']['max_run_iter']
update_fn = build_update_fn(K.shape, mapping, get_state_fn_slug, weighted_average, fft)
compute_stats_fn = leniax_stat.build_compute_stats_fn(config['world_params'], config['render_params'])
if with_jit is True:
all_cells, all_fields, all_potentials, stats_dict = leniax_runner.run_scan(
rng_key, cells, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn
)
else:
all_cells, all_fields, all_potentials, stats_dict = leniax_runner.run(
rng_key, cells, K, gf_params, kernels_weight_per_channel, T, max_run_iter, R, update_fn, compute_stats_fn, stat_trunc
) # [nb_iter, nb_init, C, world_dims...]
stats_dict = {k: v.squeeze() for k, v in stats_dict.items()}
if stat_trunc is True:
all_cells = all_cells[:int(stats_dict['N'])] # [nb_iter, C, world_dims...]
all_fields = all_fields[:int(stats_dict['N'])]
all_potentials = all_potentials[:int(stats_dict['N'])]
return all_cells, all_fields, all_potentials, stats_dict
def multi_init_and_run(rng_key, main_config, configs, use_init_cells, fft):
main_config = copy.deepcopy(main_config)
world_params = main_config['world_params']
get_state_fn_slug = world_params['get_state_fn_slug'] if 'get_state_fn_slug' in world_params else 'v1'
weighted_average = world_params['weighted_average'] if 'weighted_average' in world_params else True
R = main_config['world_params']['R']
max_run_iter = main_config['run_params']['max_run_iter']
_, K, mapping = init(main_config, use_init_cells, fft)
update_fn = build_update_fn(K.shape, mapping, get_state_fn_slug, weighted_average, fft)
compute_stats_fn = leniax_stat.build_compute_stats_fn(world_params, main_config['render_params'])
multi_run_fn = jax.vmap(leniax_runner.run_scan, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, None), out_axes=0)
multi_cells = []
multi_K = []
multi_gf_params = []
multi_kernels_weight_per_channel = []
multi_T = []
for config in configs:
cells, K, mapping = init(config, use_init_cells, fft)
gf_params = mapping.get_gf_params()
kernels_weight_per_channel = mapping.get_kernels_weight_per_channel()
T = jnp.array(config['world_params']['T'], dtype=jnp.float32)
multi_cells.append(cells)
multi_K.append(K)
multi_gf_params.append(gf_params)
multi_kernels_weight_per_channel.append(kernels_weight_per_channel)
multi_T.append(T)
all_cells, all_fields, all_potentials, stats_dict = multi_run_fn(
rng_key,
jnp.array(multi_cells),
jnp.array(multi_K),
jnp.array(multi_gf_params),
jnp.array(multi_kernels_weight_per_channel),
jnp.array(multi_T),
max_run_iter,
R,
update_fn,
compute_stats_fn
)
stats_dict = {k: v.squeeze() for k, v in stats_dict.items()}
return all_cells, all_fields, all_potentials, stats_dict
[docs]def search_for_mutation(
rng_key: jax.random.KeyArray,
config: Dict,
nb_scale_for_stability: int = 1,
use_init_cells: bool = True,
fft: bool = True,
mutation_rate: float = 1e-5,
) -> Tuple[Dict, int]:
"""Search for a stable mutation
Args:
rng_key: JAX PRNG key.
config: Lenia configuration.
use_init_cells: Set to ``True`` to use the ``init_cells`` configuration property.
fft: Set to ``True`` to use FFT optimization.
nb_scale_for_stability: Number of time the configuration will be scaled and tested.
mutation_rate: Mutation rate.
Returns:
A 2-tuple of a dictionnary with the best run data and
the number of runs made to find it
"""
render_params = config['render_params']
world_size = render_params['world_size']
run_params = config['run_params']
nb_mut_search = run_params['nb_mut_search']
max_run_iter = run_params['max_run_iter']
best_run = {}
current_max = 0
nb_genes = len(config['genotype'])
rng_key, *subkeys = jax.random.split(rng_key, nb_mut_search * nb_genes + 1)
for i in range(nb_mut_search):
copied_config = copy.deepcopy(config)
for gene_i, gene in enumerate(config['genotype']):
val = leniax_utils.get_param(copied_config, gene['key'])
subkey = subkeys[i * nb_genes + gene_i]
val += jax.random.normal(subkey, dtype=jnp.float32) * mutation_rate
leniax_utils.set_param(copied_config, gene['key'], float(val))
total_iter_done = 0
nb_iter_done = 0
for scale_power in range(nb_scale_for_stability):
# t0 = time.time()
scaled_config = copy.deepcopy(copied_config)
scaled_config['render_params']['world_size'] = [ws * 2**scale_power for ws in world_size]
scaled_config['world_params']['scale'] = 2**scale_power
# We do not split the rng_key here, because we want to keep the exact same run
# while just changing the rendering size
all_cells, _, _, stats_dict = init_and_run(
rng_key, scaled_config, use_init_cells=use_init_cells, with_jit=True, fft=fft
)
all_cells = all_cells[:, 0]
stats_dict['N'].block_until_ready()
nb_iter_done = max(nb_iter_done, stats_dict['N'])
total_iter_done += stats_dict['N']
# Current_max at all scale
if current_max < total_iter_done:
current_max = total_iter_done
best_run = {
"N": nb_iter_done,
"all_cells": all_cells,
"all_stats": stats_dict,
"config": copied_config,
}
if total_iter_done >= max_run_iter * nb_scale_for_stability:
break
return best_run, i
[docs]def search_for_init(
rng_key: jax.random.KeyArray,
config: Dict,
fft: bool = True,
) -> Tuple[Dict, int]:
"""Search for a stable initial state
Args:
rng_key: JAX PRNG key.
config: Lenia configuration.
fft: Set to ``True`` to use FFT optimization.
Returns:
A 2-tuple of a dictionnary with the best run data and
the number of runs made to find it
"""
world_params = config['world_params']
nb_channels = world_params['nb_channels']
get_state_fn_slug = world_params['get_state_fn_slug'] if 'get_state_fn_slug' in world_params else 'v1'
weighted_average = world_params['weighted_average'] if 'weighted_average' in world_params else True
R = world_params['R']
T = jnp.array(world_params['T'], dtype=jnp.float32)
render_params = config['render_params']
world_size = render_params['world_size']
kernels_params = config['kernels_params']
run_params = config['run_params']
nb_init_search = run_params['nb_init_search']
max_run_iter = run_params['max_run_iter']
init_slug = config['algo']['init_slug']
K, mapping = leniax_kernels.get_kernels_and_mapping(kernels_params, world_size, nb_channels, R, fft)
gf_params = mapping.get_gf_params()
kernels_weight_per_channel = mapping.get_kernels_weight_per_channel()
update_fn = build_update_fn(K.shape, mapping, get_state_fn_slug, weighted_average, fft)
compute_stats_fn = leniax_stat.build_compute_stats_fn(config['world_params'], config['render_params'])
nb_channels_to_init = nb_channels * nb_init_search
rng_key, noises = leniax_init.register[init_slug](
rng_key, nb_channels_to_init, world_size, R, kernels_params[0]['gf_params']
)
init_noises = noises.reshape([nb_init_search, nb_channels] + world_size)
all_cells0_l = [create_init_cells(world_size, nb_channels, [init_noises[i]]) for i in range(nb_init_search)]
all_cells0_jnp = jnp.array(all_cells0_l, dtype=jnp.float32)
best_run = {}
current_max = 0
subkeys = jax.random.split(rng_key, nb_init_search)
for i in range(nb_init_search):
all_cells, _, _, all_stats = leniax_runner.run_scan(
subkeys[i],
all_cells0_jnp[i],
K,
gf_params,
kernels_weight_per_channel,
T,
max_run_iter,
R,
update_fn,
compute_stats_fn
)
# https://jax.readthedocs.io/en/latest/async_dispatch.html
all_stats['N'].block_until_ready()
nb_iter_done = all_stats['N']
if current_max < nb_iter_done:
current_max = nb_iter_done
best_run = {"N": nb_iter_done, "all_cells": all_cells, "all_stats": all_stats}
if nb_iter_done >= max_run_iter:
break
return best_run, i
###
# Pipeline dynamic builder
###
[docs]def build_update_fn(
kernel_shape: Tuple[int, ...],
mapping: leniax_kernels.KernelMapping,
get_state_fn_slug: str = 'v1',
average_weight: bool = True,
fft: bool = True,
) -> Callable:
"""Construct an Leniax update function
An update function allows one to update a simulation state.
Args:
kernel_shape: Kernel shape.
mapping: Mapping data.
get_state_fn_slug: Which version of Lenia should be run
fft: Set to ``True`` to use FFT optimization
Returns:
A Leniax update function
"""
get_potential_fn = build_get_potential_fn(kernel_shape, mapping.true_channels, fft)
get_field_fn = build_get_field_fn(mapping.cin_gfs, average_weight)
get_state_fn = leniax_core.register[get_state_fn_slug]
return functools.partial(
leniax_core.update, get_potential_fn=get_potential_fn, get_field_fn=get_field_fn, get_state_fn=get_state_fn
)
[docs]def build_get_potential_fn(
kernel_shape: Tuple[int, ...],
true_channels: Optional[List[bool]] = None,
fft: bool = True,
channel_first: bool = True,
) -> Callable:
"""Construct an Leniax potential function
A potential function allows one to compute the potential from a Lenia state.
Args:
kernel_shape: Kernel shape.
true_channels: Boolean array indicating the true potential channels
fft: Set to ``True`` to use FFT optimization
Returns:
A Leniax potential function
"""
# First 2 dimensions are for fake batch dim and nb_channels
if true_channels is not None:
tc_indices_l = []
for i in range(len(true_channels)):
if true_channels[i] is True:
tc_indices_l.append(i)
tc_indices = tuple(tc_indices_l)
else:
tc_indices = None
if fft is True:
return functools.partial(
leniax_core.get_potential_fft,
tc_indices=tc_indices,
channel_first=channel_first,
)
else:
if channel_first is True:
pad_l = [(0, 0), (0, 0)]
for dim in kernel_shape[2:]:
if dim % 2 == 0:
pad_l += [(dim // 2, dim // 2 - 1)]
else:
pad_l += [(dim // 2, dim // 2)]
else:
pad_l = [(0, 0)]
for dim in kernel_shape[:2]:
if dim % 2 == 0:
pad_l += [(dim // 2, dim // 2 - 1)]
else:
pad_l += [(dim // 2, dim // 2)]
pad_l += [(0, 0)]
padding = tuple(pad_l)
return functools.partial(
leniax_core.get_potential,
tc_indices=tc_indices,
padding=padding,
channel_first=channel_first,
)
[docs]def build_get_field_fn(cin_gfs: List[List[str]], average: bool = True) -> Callable:
"""Construct an Leniax field function
A field function allows one to compute the field from a Lenia potential.
Args:
cin_gfs: List of growth functions per channel.
average: Set to ``True`` to average instead of summing input channels
Returns:
A Leniax field function
"""
growth_fn_l = []
for growth_fns_per_channel in cin_gfs:
for gf_slug in growth_fns_per_channel:
growth_fn = gf_register[gf_slug]
growth_fn_l.append(growth_fn)
growth_fn_t = tuple(growth_fn_l)
if average:
weighted_fn = leniax_core.weighted_mean
else:
weighted_fn = leniax_core.weighted_sum
return functools.partial(leniax_core.get_field, growth_fn_t=growth_fn_t, weighted_fn=weighted_fn)
###
# Viz
###
[docs]def dump_assets(
save_dir: str,
config: Dict,
all_cells: jnp.ndarray,
stats_dict: Dict,
colormaps: List = [],
transparent_bg: bool = False
):
"""Dump a set of interesting assets.
Those assets include:
- Simulation statistics (plots and data)
- Kernels and growth functions plots
- Last frame
- Video and Gif of the simulation
Args:
save_dir: directory used to save assets.
config: Leniax configuration.
all_cells: Simulation data of shape ``[nb_iter, C, world_dims...]``.
stats_dict: Leniax statistics dictionnary.
colormaps: A List of matplotlib compatible colormap.
transparent_bg: Set to ``True`` to make the background transparent.
"""
if len(colormaps) == 0:
colormaps = [
leniax_colormaps.ExtendedColormap('extended')
] # https://matplotlib.org/stable/tutorials/colors/colormaps.html
leniax_utils.plot_stats(save_dir, stats_dict)
plot_kernels(save_dir, config)
with open(os.path.join(save_dir, 'stats_dict.p'), 'wb') as f:
pickle.dump(stats_dict, f)
# with open(os.path.join(save_dir, 'cells.p'), 'wb') as f:
# np.save(f, np.array(all_cells))
dump_last_frame(save_dir, all_cells, False, colormaps[0])
dump_viz_data(save_dir, config, stats_dict)
logging.info("Video rendering: start.")
start_time = time.time()
all_outputs_fullpath = leniax_video.render_video(
save_dir, all_cells, config['render_params'], colormaps, '', transparent_bg
)
total_time = time.time() - start_time
logging.info(
f"Video rendering: stop. {len(all_cells)} states computed in {total_time:.2f} seconds, {len(all_cells) / total_time:.2f} fps."
)
logging.info("Gif rendering: start.")
start_time = time.time()
for output_fullpath in all_outputs_fullpath:
leniax_video.render_gif(output_fullpath)
total_time = time.time() - start_time
logging.info(
f"Gif rendering: stop. {len(all_cells)} states computed in {total_time:.2f} seconds, {len(all_cells) / total_time:.2f} fps."
)
[docs]def dump_last_frame(
save_dir: str,
all_cells: jnp.ndarray,
center_and_crop: bool = True,
colormap=None,
):
"""Dump the last frame of the simulation
The dumped last frame is called **last_frame.png**.
Args:
save_dir: directory used to save assets.
all_cells: Simulation data of shape ``[nb_iter, C, world_dims...]``.
center_and_crop: Set to ``True`` to center the stable pattern and crop the margin.
colormap: A matplotlib compatible colormap.
"""
last_frame = all_cells[-1]
dump_frame(save_dir, 'last_frame', last_frame, center_and_crop, colormap)
[docs]def dump_frame(
save_dir: str,
filename: str,
cells: jnp.ndarray,
center_and_crop: bool = True,
colormap=None,
):
"""Dump a Lenia state as a image
Args:
save_dir: directory used to save assets.
filename: File name.
cells: A Lenia state of shape ``[C, world_dims...]``.
center_and_crop: Set to ``True`` to center the stable pattern and crop the margin.
colormap: A matplotlib compatible colormap.
"""
if center_and_crop is True:
cells = leniax_utils.center_and_crop(cells)
if colormap is None:
colormap = plt.get_cmap('plasma')
# with open(os.path.join(save_dir, f"{filename}.p"), 'wb') as f:
# pickle.dump(np.array(cells), f)
img = leniax_utils.get_image(np.array(cells), 1, colormap)
with open(os.path.join(save_dir, f"{filename}.png"), 'wb') as f:
img.save(f, format='png')
[docs]def dump_viz_data(save_dir: str, config: Dict, stats_dict: Dict):
"""Dump vizualization data as JSON
Args:
save_dir: directory used to save assets.
config: Leniax configuration.
stats_dict: Leniax statistics dictionnary.
"""
viz_data: Dict = {'stats': {}}
for k, v in stats_dict.items():
if k == 'N':
continue
truncated_stat = v[:int(stats_dict['N'])]
viz_data['stats'][k + '_mean'] = round(float(truncated_stat[-128:].mean()), 5)
viz_data['stats'][k + '_std'] = round(float(truncated_stat[-128:].std()), 5)
viz_data['img_url'] = 'last_frame.png'
viz_data['k'] = config['kernels_params']
viz_data['R'] = config['world_params']['R']
viz_data['T'] = config['world_params']['T']
with open(os.path.join(save_dir, 'viz_data.json'), 'w') as fviz:
json.dump(viz_data, fviz)
[docs]def plot_kernels(save_dir: str, config: Dict):
"""Plots kernels and growth functions
Args:
save_dir: directory used to save assets.
config: Leniax configuration.
"""
R = config['world_params']['R']
scale = config['world_params']['scale']
C = config['world_params']['nb_channels']
x = jnp.linspace(0, 1, 1000)
Ks, _ = leniax_kernels.get_kernels_and_mapping(
config['kernels_params'], config['render_params']['world_size'], C, R, fft=False
)
Ks = Ks[:, 0]
nb_Ks = Ks.shape[0]
all_kfs = []
all_gfs = []
for param in config['kernels_params']:
if param['kf_slug'] != '':
all_kfs.append(kf_register[param['kf_slug']](param['kf_params'], x))
all_gfs.append(gf_register[param['gf_slug']](param["gf_params"], x) * param['h'])
# Plot kernels image where color represent intensity
rows = int(nb_Ks**0.5)
cols = nb_Ks // rows
if cols * rows < nb_Ks:
cols += 1
axes = []
fig = plt.figure(figsize=(6, 6))
# We assume the kernel has the same size on all dimensions
K_size = Ks.shape[-1]
K_mid = K_size // 2
vmax = Ks.max()
vmin = Ks.min()
fullpath = f"{save_dir}/Ks.png"
if len(Ks.shape) == 3:
# 2D kernels
for i in range(nb_Ks):
axes.append(fig.add_subplot(rows, cols, i + 1))
axes[-1].title.set_text(f"kernel K{i}")
plt.imshow(Ks[i], cmap='viridis', interpolation="nearest", vmin=vmin, vmax=vmax)
elif len(Ks.shape) == 4:
# 3D kernels
for i in range(nb_Ks):
K_to_plot = np.array(Ks[i] > 0) # Need a real numpy array to work with the voxels function
K_to_plot[K_mid:] = 0
K_to_plot[:, :, K_mid:] = 0
colors = plt.get_cmap('viridis')(Ks[i] / vmax)
colors[:, :, :, 3] = 0.5
axes.append(fig.add_subplot(rows, cols, i + 1, projection='3d'))
axes[-1].title.set_text(f"kernel K{i}")
axes[-1].voxels(K_to_plot, facecolors=colors)
else:
raise ValueError('We do not support plotting kernels containing more than 3 dimensions')
plt.tight_layout()
fig.savefig(fullpath)
plt.close(fig)
# Plot Kernels and growth functions
fullpath = f"{save_dir}/Ks_graph.png"
fig, ax = plt.subplots(1, 3, figsize=(10, 2))
if len(all_kfs) > 0:
ax[0].plot(x, jnp.asarray(all_kfs).T)
ax[0].axhline(y=0, color='grey', linestyle='dotted')
ax[0].title.set_text('Kernel functions')
if len(Ks.shape) == 3:
ax[1].plot(range(K_size), Ks[:, K_mid, :].T)
elif len(Ks.shape) == 4:
ax[1].plot(range(K_size), Ks[:, K_mid, K_mid, :].T)
ax[1].title.set_text('Ks cross-sections')
ax[1].set_xlim([K_mid - 3 - R * scale, K_mid + 3 + R * scale])
ax[2].plot(x, jnp.asarray(all_gfs).T)
ax[2].axhline(y=0, color='grey', linestyle='dotted')
ax[2].title.set_text('Growth functions')
plt.tight_layout()
fig.savefig(fullpath)
plt.close(fig)