# Source code for leniax.kernels

import math
from dataclasses import dataclass
from typing import List, Dict, Tuple, Callable, Optional
import jax.numpy as jnp

from . import utils as leniax_utils
from .kernel_functions import register as kf_register


@dataclass()
class KernelMapping(object):
    """Explicit mapping of the computation graph.

    Every ``cin_*`` attribute is a list with one (possibly empty) sub-list per
    input channel; the sub-lists are filled by whoever builds the mapping
    (``get_kernels_and_mapping`` in this module).

    Attributes:
        cin_kernels: kernel indexes grouped by input channel, of shape
            ``[nb_channels, nb_kernels_of_that_channel]``
        cin_k_params: kernel parameters grouped by input channel
        cin_kfs: kernel function slugs grouped by input channel
        cin_gfs: growth function slugs grouped by input channel
        cin_gf_params: growth function parameters grouped by input channel
        kernels_weight_per_channel: weights grouped by output channel, of
            shape ``[nb_channels, nb_kernels]`` (initialised to ``0.``)
        true_channels: optional flat boolean mask telling which kernel slots
            are real kernels (``True``) and which are zero padding
            (``False``); ``None`` means that no slot is padding
    """

    cin_kernels: List[List[int]]
    cin_k_params: List[List]
    cin_kfs: List[List[str]]
    cin_gfs: List[List[str]]
    cin_gf_params: List[List]
    kernels_weight_per_channel: List[List[float]]
    true_channels: Optional[List[bool]] = None

    def __init__(self, nb_channels: int, nb_kernels: int):
        """Create an empty mapping.

        Args:
            nb_channels: Number of channels (one sub-list per channel).
            nb_kernels: Total number of kernels (width of each weight row).
        """
        # NOTE: this hand-written __init__ is kept by @dataclass, which never
        # overwrites a method that the class defines itself.
        self.cin_kernels = [[] for _ in range(nb_channels)]
        self.cin_k_params = [[] for _ in range(nb_channels)]
        self.cin_kfs = [[] for _ in range(nb_channels)]
        self.cin_gfs = [[] for _ in range(nb_channels)]
        self.cin_gf_params = [[] for _ in range(nb_channels)]
        # One independent row per channel so rows never alias each other.
        self.kernels_weight_per_channel = [[0.] * nb_kernels for _ in range(nb_channels)]
        # Set explicitly instead of relying on the class-level default.
        self.true_channels = None

    def get_k_params(self) -> jnp.ndarray:
        """Get all kernel params, flattened over channels.

        Returns:
            An array of shape ``[nb_kernels, kernel_params_shape...]``
        """
        flat_k_params = [item for sublist in self.cin_k_params for item in sublist]

        return jnp.array(flat_k_params, dtype=jnp.float32)

    def get_gf_params(self) -> jnp.ndarray:
        """Get all growth function params, flattened over channels.

        Returns:
            An array of shape ``[nb_kernels, gf_params_shape...]``
        """
        flat_gf_params = [item for sublist in self.cin_gf_params for item in sublist]

        return jnp.array(flat_gf_params, dtype=jnp.float32)

    def get_kernels_weight_per_channel(self) -> jnp.ndarray:
        """Get the kernel weights of every channel as an array.

        Returns:
            An array of shape ``[nb_channels, nb_kernels]``
        """
        return jnp.array(self.kernels_weight_per_channel, dtype=jnp.float32)
def get_kernels_and_mapping(
    kernels_params: List,
    world_size: List[int],
    nb_channels: int,
    R: float,
    fft: bool = True,
) -> Tuple[jnp.ndarray, KernelMapping]:
    """Construct the kernel array and the associated mapping.

    Build the kernel and the mapping used in the update function by JAX.
    To take advantage of JAX vmap and conv_generalized function, we stack all
    kernels and keep track of the kernel graph in mapping:

    - Each kernel is applied to one channel and the result is added to one
      other channel

    Note:
        ``kernels_params`` is sorted **in place** by ``c_in``; the kernel
        indexes stored in the mapping refer to that sorted order.

    Args:
        kernels_params: Kernels parameters. Each item is a mapping read with
            the keys ``k_slug``, ``k_params``, ``kf_slug``, ``kf_params``,
            ``gf_slug``, ``gf_params``, ``c_in``, ``c_out`` and ``h``.
        world_size: World size.
        nb_channels: Number of channels.
        R: World radius.
        fft: Truthy to compute a FFT kernel.

    Returns:
        A tuple ``(K, mapping)``. With ``fft`` set, ``K`` is the FFT of the
        padded kernels with shape
        ``[1, nb_channels, max_k_per_channel, world_dims...]``; otherwise it
        has shape ``[nb_channels * max_k_per_channel, 1, K_h, K_w]``.

    Raises:
        ValueError: If a kernel is larger than the world along any dimension.
    """
    kernels_list = []
    mapping = KernelMapping(nb_channels, len(kernels_params))

    # We need to sort those kernels_params to make sure
    # kernels_weight_per_channel is in the right order
    kernels_params.sort(key=lambda d: d['c_in'])

    for kernel_idx, param in enumerate(kernels_params):
        k = register[param['k_slug']](R, param['k_params'], param['kf_slug'], param['kf_params'])

        # Pad every kernel up to the world size; the leading axis is left as is.
        padding = [[0, 0]]
        for ws, ks in zip(world_size, k.shape[1:]):
            pad_needed = ws - ks
            if pad_needed < 0:
                raise ValueError(
                    f"Kernel {kernel_idx} has size {ks} along a dimension where the world size is only {ws}"
                )
            # When the padding is odd, the extra cell goes after the kernel.
            pad_before = pad_needed // 2
            padding.append([pad_before, pad_needed - pad_before])
        kernels_list.append(jnp.pad(k, padding))

        channel_in = param["c_in"]
        mapping.cin_kernels[channel_in].append(kernel_idx)
        mapping.cin_gfs[channel_in].append(param["gf_slug"])
        mapping.cin_gf_params[channel_in].append(param["gf_params"])
        mapping.cin_kfs[channel_in].append(param["kf_slug"])
        mapping.cin_k_params[channel_in].append(param["k_params"])

        channel_out = param["c_out"]
        mapping.kernels_weight_per_channel[channel_out][kernel_idx] = param["h"]

    # ! vstack concatenates on the first dimension
    # Currently it works, because we only consider 1-channel kernels
    kernels = jnp.vstack(kernels_list)  # [nb_kernels, H, W]

    if not fft:
        # NOTE(review): crop_zero presumably strips the all-zero borders added
        # by the padding above -- confirm against leniax.utils.
        kernels = leniax_utils.crop_zero(kernels)  # [nb_kernels, K_h=max_k_h, K_w=max_k_w]

    K_h = kernels.shape[-2]
    K_w = kernels.shape[-1]

    max_k_per_channel = max(map(len, mapping.cin_kernels))
    kernels_per_channel = []
    true_channels = []
    for cin_kernel_list in mapping.cin_kernels:
        if len(cin_kernel_list):
            K_tmp = kernels[jnp.array(cin_kernel_list)]  # [O=nb_kernel_per_channel, K_h, K_w]
        else:
            K_tmp = jnp.zeros([0, K_h, K_w])

        # We create the mask: real kernels first, then the zero padding
        nb_k = K_tmp.shape[0]
        nb_missing = max_k_per_channel - nb_k
        true_channels += [True] * nb_k + [False] * nb_missing

        if nb_missing > 0:
            K_tmp = jnp.vstack([K_tmp, jnp.zeros([nb_missing, K_h, K_w])])  # [O=max_k_per_channel, K_h, K_w]

        K_tmp = K_tmp[:, jnp.newaxis, ...]  # [O=max_k_per_channel, I=1, K_h, K_w]
        kernels_per_channel.append(K_tmp)

    # No mask is stored when every slot holds a real kernel.
    mapping.true_channels = None if all(true_channels) else true_channels

    if fft:
        # Transform over the trailing world dimensions only.
        axes = list(range(-1, -len(world_size) - 1, -1))
        K = jnp.stack(kernels_per_channel)[jnp.newaxis, :, :, 0, ...]  # [1, nb_channels, max_k_per_channel, H, W]
        K = jnp.fft.fftshift(K, axes=axes)
        K = jnp.fft.fftn(K, axes=axes)

        assert K.shape[0] == 1
        assert K.shape[1] == nb_channels
        assert K.shape[2] == max_k_per_channel
    else:
        K = jnp.vstack(kernels_per_channel)  # [O=nb_channels*max_k_per_channel, I=1, K_h, K_w]

        assert K.shape[0] == nb_channels * max_k_per_channel
        assert K.shape[1] == 1

    return K, mapping
def raw(R, k_params: List, kf_slug: str, kf_params: jnp.ndarray) -> jnp.ndarray:
    """Use ``k_params`` directly as the kernel values.

    Args:
        R: World radius **(not used)**
        k_params: Kernel values, converted as-is to a ``float32`` array
        kf_slug: Kernel function slug **(not used)**
        kf_params: Kernel function params **(not used)**

    Returns:
        A ``float32`` kernel with the same shape as ``k_params``
    """
    kernel = jnp.array(k_params, dtype=jnp.float32)

    return kernel
def circle_2d(R, k_params: List, kf_slug: str, kf_params: jnp.ndarray) -> jnp.ndarray:
    """Build a 2D kernel made of concentric rings.

    Args:
        R: World radius
        k_params: Kernel parameters ``[r, bs]``: the relative radius and the
            list of weights of each ring
        kf_slug: Kernel function slug, looked up in the kernel function register
        kf_params: Kernel function params

    Returns:
        A kernel of shape ``[1, side, side]`` with ``side = 2 * ceil(r * R)``,
        normalised so that it sums to 1
    """
    relative_radius, ring_weights = k_params[0], jnp.array(k_params[1], dtype=jnp.float32)
    half_side = math.ceil(relative_radius * R)
    side = half_side * 2

    # Grid coordinates, centered and expressed in units of the kernel radius
    center = jnp.array([half_side, half_side]).reshape([-1, 1, 1])  # [nb_dims, 1, 1]
    grid = jnp.indices([side, side])  # [nb_dims, dim_0, dim_1]
    normalized = (grid - center) / (relative_radius * R)  # [nb_dims, dim_0, dim_1]

    # Euclidean distance of every cell to the center
    radial = jnp.sqrt(jnp.sum(normalized**2, axis=0))  # [dim_0, dim_1]

    # Stretch the distances so that each ring spans one unit, then pick the
    # weight of the ring each cell falls in (clamped to the last ring)
    nb_rings = ring_weights.shape[0]
    scaled = nb_rings * radial
    ring_idx = jnp.minimum(jnp.floor(scaled).astype(int), nb_rings - 1)
    weights = ring_weights[ring_idx]

    # Kernel functions are defined on [0, 1], hence the modulo 1 that gives the
    # position inside the current ring. They are expected to be positive with
    # fk(0) = 0 and fk(1) = 0. Cells at distance >= 1 are zeroed, which defines
    # the "size" of the kernel.
    shell = kf_register[kf_slug](kf_params, scaled % 1)
    kernel = (radial < 1) * shell * weights  # type: ignore
    kernel = kernel / kernel.sum()

    return kernel[jnp.newaxis, ...]  # [1, dim_0, dim_1]
def _ellipse_profile(R, k_params: List, kf_slug: str, kf_params: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the un-normalised elliptic rings shared by the ellipse kernels.

    Args:
        R: World radius
        k_params: Kernel parameters ``[r, bs, a, b, theta]``
        kf_slug: Kernel function slug
        kf_params: Kernel function params

    Returns:
        A tuple ``(profile, rot_coord_0)`` of two ``[dim_0, dim_1]`` arrays:
        the un-normalised kernel values and the first rotated coordinate.
    """
    r = k_params[0]
    k_radius_px = math.ceil(r * R)
    bs = jnp.array(k_params[1], dtype=jnp.float32)
    a = k_params[2]
    b = k_params[3]
    theta = k_params[4] * jnp.pi  # the angle is given as a multiple of pi

    midpoint = jnp.array([k_radius_px, k_radius_px]).reshape([-1, 1, 1])  # [nb_dims, 1, 1]
    coords = jnp.indices([k_radius_px * 2, k_radius_px * 2])  # [nb_dims, dim_0, dim_1]
    centered_coords = (coords - midpoint) / (r * R)  # [nb_dims, dim_0, dim_1]

    # Rotate the grid by theta, then measure distances with each rotated axis
    # scaled by its own factor (a for the first axis, b for the second)
    cos_theta = jnp.cos(theta)
    sin_theta = jnp.sin(theta)
    rot_coord_0 = centered_coords[0] * cos_theta + centered_coords[1] * sin_theta
    rot_coord_1 = -centered_coords[0] * sin_theta + centered_coords[1] * cos_theta
    distances = jnp.sqrt((rot_coord_0 / a)**2 + (rot_coord_1 / b)**2)

    nb_b = bs.shape[0]
    B_dist = nb_b * distances  # scale distances by the number of modes
    bs_mat = bs[jnp.minimum(jnp.floor(B_dist).astype(int), nb_b - 1)]  # Define positions for each mode

    # All kernel functions are defined in [0, 1] so we compute B_dist modulo 1
    # to make sure the value lies in the right range.
    # Notice that those kernel functions should be positive and fk(0) = 0, fk(1) = 0
    # We then crop to distances < 1 which defines the "size" of the kernel
    raw_kernel = kf_register[kf_slug](kf_params, B_dist % 1)
    profile = (distances < 1) * raw_kernel * bs_mat  # type: ignore

    return profile, rot_coord_0


def ellipse_2d(R, k_params: List, kf_slug: str, kf_params: jnp.ndarray) -> jnp.ndarray:
    """Build a rotated ellipse kernel whose sign flips across its center.

    The elliptic profile is normalised by its sum, then multiplied by a mask
    built from the first rotated coordinate: ``-1`` where it is below
    ``-0.01``, ``1`` where it is above ``0.01``, and the coordinate itself in
    between.

    Args:
        R: World radius
        k_params: Kernel parameters ``[r, bs, a, b, theta]``: relative radius,
            ring weights, scale of each rotated axis and rotation angle given
            as a multiple of pi
        kf_slug: Kernel function slug
        kf_params: Kernel function params

    Returns:
        A kernel of shape ``[1, side, side]`` with ``side = 2 * ceil(r * R)``
    """
    kernel, rot_coord_0 = _ellipse_profile(R, k_params, kf_slug, kf_params)
    kernel = kernel / kernel.sum()

    # jnp.where gives the same values as a boolean-mask scatter
    # (``x.at[mask].set(v)``) without needing a concrete mask.
    grad_coord = jnp.where(rot_coord_0 < -0.01, -1, jnp.where(rot_coord_0 > 0.01, 1, rot_coord_0))
    kernel = kernel * grad_coord

    kernel = kernel[jnp.newaxis, ...]  # [1, dim_0, dim_1]

    return kernel


def oriented_ellipse_2d(R, k_params: List, kf_slug: str, kf_params: jnp.ndarray) -> jnp.ndarray:
    """Build a rotated ellipse kernel weighted by its first rotated coordinate.

    The elliptic profile is multiplied by the first rotated coordinate, then
    normalised so that the absolute values of the kernel sum to 1.

    Args:
        R: World radius
        k_params: Kernel parameters ``[r, bs, a, b, theta]``: relative radius,
            ring weights, scale of each rotated axis and rotation angle given
            as a multiple of pi
        kf_slug: Kernel function slug
        kf_params: Kernel function params

    Returns:
        A kernel of shape ``[1, side, side]`` with ``side = 2 * ceil(r * R)``
    """
    kernel, rot_coord_0 = _ellipse_profile(R, k_params, kf_slug, kf_params)
    kernel = kernel * rot_coord_0
    kernel = kernel / jnp.absolute(kernel).sum()

    kernel = kernel[jnp.newaxis, ...]  # [1, dim_0, dim_1]

    return kernel


# Kernel builders indexed by the ``k_slug`` of a kernel's parameters
register: Dict[str, Callable] = {
    'raw': raw,
    'circle_2d': circle_2d,
    'ellipse_2d': ellipse_2d,
    'oriented_ellipse_2d': oriented_ellipse_2d,
}