Source code for leniax.kernel_functions

from typing import Dict, Callable
import jax.numpy as jnp

from .constant import EPSILON


[docs]def poly_quad(params: jnp.ndarray, X: jnp.ndarray): r"""Quadratic polynomial .. math:: y = (q * X * (1 - X))^q .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import poly_quad x = np.linspace(0., 1., 100) q = 4 params = np.array([q]) y = poly_quad(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Poly_quad4: q=%.2f'%(q)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ q = params[0] out = 4 * X * (1 - X) out = out**q return out
[docs]def gauss_bump(params: jnp.ndarray, X: jnp.ndarray): r"""Gaussian bump function .. math:: y = e^{q * [q - \frac{1}{X * (1 - X)}]} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import gauss_bump x = np.linspace(0., 1., 100) q = 4 params = np.array([q]) y = gauss_bump(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Gauss bump: q=%.2f'%(q)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ q = params[0] out = q - 1 / (X * (1 - X) + EPSILON) out = jnp.exp(q * out) return out
[docs]def step(params: jnp.ndarray, X: jnp.ndarray): r"""Step function .. math:: y = \begin{cases} 1 ,& \text{if } X \geq q \text{ and } X \leq 1 - q\\ 0 ,& \text{otherwise} \end{cases} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import step x = np.linspace(0., 1., 100) q = 0.3 params = np.array([q]) y = step(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Step: q=%.2f'%(q)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ q = params[0] return (X >= q) * (X <= 1 - q)
[docs]def gauss(params: jnp.ndarray, X: jnp.ndarray): r"""Gauss function .. math:: y = e^{-\frac{1}{2} \left(\frac{X - q}{0.3 * q}\right)^2} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import gauss x = np.linspace(0., 1., 100) q = 0.15 params = np.array([q]) y = gauss(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Gauss: q=%.2f'%(q)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ q = params[0] out = ((X - q) / (0.3 * q))**2 out = jnp.exp(-out / 2) return out
def threshold(params: jnp.ndarray, X: jnp.ndarray): r"""Threshold function .. math:: y = \begin{cases} 1 ,& \text{if } X \geq q 0 ,& \text{otherwise} \end{cases} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import step x = np.linspace(0., 1., 100) q = 0.3 params = np.array([q]) y = step(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Threshold: q=%.2f'%(q)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ q = params[0] return X >= q
[docs]def staircase(params: jnp.ndarray, X: jnp.ndarray): r"""Staircase function .. math:: y = \begin{cases} 0.5 ,& \text{if } X \geq m - s \text{ and } X < m - \frac{s}{2}\\ 1 ,& \text{if } X \geq m - \frac{s}{2} \text{ and } X \leq m + \frac{s}{2}\\ 0.5 ,& \text{if } X > m + \frac{s}{2} \text{ and } X \leq m + s\\ 0 ,& \text{otherwise} \end{cases} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import staircase x = np.linspace(0., 1., 100) m = 0.5 s = 0.1 params = np.array([m, s]) y = staircase(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Staircase: m=%.2f, s=%.2f'%(m, s)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ m = params[0] s = params[1] out = 0.5 * (X >= m - s) * (X < m - s / 2) out += 1 * (X >= m - s / 2) * (X <= m + s / 2) out += 0.5 * (X > m + s / 2) * (X <= m + s) return out
[docs]def triangle(params: jnp.ndarray, X: jnp.ndarray): r"""Gauss function .. math:: y = \begin{cases} \frac{X - (m - s)}{m - (m - s)} ,& \text{if } X \geq m - s \text{ and } X < m\\ \frac{X - (m + s)}{m - (m + s)} ,& \text{if } X \geq m \text{ and } X \leq m + s\\ 0 ,& \text{otherwise} \end{cases} .. plot:: import matplotlib.pyplot as plt import numpy as np from leniax.kernel_functions import triangle x = np.linspace(0., 1., 100) m = 0.5 s = 0.1 params = np.array([m, s]) y = triangle(params, x) plt.plot(x, y) plt.axhline(y=0, color='grey', linestyle='dotted') plt.title(r'Triangle: m=%.2f, s=%.2f'%(m, s)) plt.show() Args: params: Kernel function parameters X: Raw kernel """ m = params[0] s = params[1] left = m - s right = m + s out = (X >= left) * (X < m) * (X - left) / (m - left) out += (X >= m) * (X <= right) * (X - right) / (m - right) return out
register: Dict[str, Callable] = { 'poly_quad': poly_quad, 'gauss_bump': gauss_bump, 'gauss': gauss, 'step': step, 'threshold': threshold, 'staircase': staircase, 'triangle': triangle, }