Source code for leniax.loader

# import time
import io
import binascii
import itertools
import os
import codecs
import pickle
import gzip
import jax.numpy as jnp
import numpy as np
from typing import List, Any, Dict

from .constant import NB_CHARS


def make_array_compressible(cells: jnp.ndarray) -> jnp.ndarray:
    """Snap cell values onto the grid of values the string encoders can represent.

    The values are scaled by ``NB_CHARS**2 - 1``, rounded to the nearest
    integer (stored as int32) and scaled back to float32.

    Args:
        cells: Cells state.

    Returns:
        The rounded cells state, as a float32 array.
    """
    nb_levels = NB_CHARS**2 - 1

    # Quantize to integer levels first, then map the levels back to floats.
    quantized = jnp.array(jnp.round(cells * nb_levels), dtype=jnp.int32)
    return jnp.array(quantized / nb_levels, dtype=jnp.float32)
def compress_array(cells: jnp.ndarray) -> str:
    """Compress a cells state into a base64 utf-8 string.

    .. note::
        The cells state is first quantized to int32 (scaled by
        ``NB_CHARS**2 - 1`` and rounded). The serialized payload is then:

        1. the number of int32 values, as 4 little endian bytes,
        2. the raw bytes of the flattened int32 array,
        3. each dimension of the shape, as 4 little endian bytes.

        The payload is gzip-compressed and the resulting bytes are encoded
        as base64, then decoded as utf-8 text.

    Args:
        cells: Cells state

    Returns:
        Cells state encoded as a string.
    """
    nb_levels = NB_CHARS**2 - 1
    quantized = jnp.array(jnp.round(cells * nb_levels), dtype=jnp.int32)

    flat = quantized.flatten()
    count_bytes = int.to_bytes(len(flat), length=4, byteorder='little')
    shape_bytes = b''.join(int.to_bytes(dim, length=4, byteorder='little') for dim in quantized.shape)

    gzip_bytes = gzip.compress(count_bytes + flat.tobytes() + shape_bytes)
    return str(codecs.encode(gzip_bytes, 'base64'), 'utf-8')
def _decompress_array_chars(string_cells: str) -> jnp.ndarray:
    """Decode the ``<two chars per value>::<d0;d1;...>`` text format."""
    parts = string_cells.split('::')
    # Exactly one '::' separator and an even number of value characters are required.
    if len(parts) != 2 or len(parts[0]) % 2 != 0:
        raise ValueError("not a two-character encoded cells state")

    chars, shape_str = parts
    max_val = NB_CHARS**2 - 1
    cells_shape = [int(c) for c in shape_str.split(";")]
    cells_val_l = [ch2val(chars[i:i + 2]) for i in range(0, len(chars), 2)]
    cells_int32 = jnp.array(cells_val_l, dtype=jnp.int32).reshape(cells_shape)
    return jnp.array(cells_int32 / max_val, dtype=jnp.float32)


def _decompress_array_npz(string_cells: str):
    """Decode a latin1-encoded numpy archive holding uint8 cells under the key ``'x'``."""
    string_bytes = io.BytesIO(string_cells.encode('latin1'))
    with np.load(string_bytes) as archive:  # type: ignore
        cells_uint8 = archive['x']
    return cells_uint8 / 255.


def decompress_array(string_cells: str, nb_dims: int = 0) -> jnp.ndarray:
    """Best effort helper which tries all existing decompress functions built so far.

    The decoders are tried in this order, moving on to the next one as soon
    as one raises:

    1. gzip + base64 (``decompress_array_gzip``),
    2. two characters per value followed by ``::`` and the shape,
    3. latin1-encoded numpy archive,
    4. base64-encoded pickle (``decompress_array_base64``),
    5. deprecated run-length format (``deprecated_decompress_array``).

    Any failure of the first four decoders falls through to the next one;
    in particular an unpickling error in step 4 no longer escapes before the
    deprecated decoder has been tried. Errors raised by the last decoder
    propagate to the caller.

    .. warning::
        Step 4 unpickles the decoded bytes. Only call this function on
        strings coming from a trusted source.

    Args:
        string_cells: A string encoded cells state.
        nb_dims: the number of dimensions in the cells state. Only used by
            the deprecated run-length decoder.

    Returns:
        The decoded cells state array, as float32.
    """
    decoders = (
        decompress_array_gzip,
        _decompress_array_chars,
        _decompress_array_npz,
        decompress_array_base64,
    )
    for decode in decoders:
        try:
            cells = decode(string_cells)
        except Exception:
            # Deliberate best effort: a failure only means "not this format".
            continue
        break
    else:
        cells = deprecated_decompress_array(string_cells, nb_dims)

    return jnp.array(cells, dtype=jnp.float32)
def decompress_array_gzip(string_cells: str) -> jnp.ndarray:
    """Decompress string encoded cells state using the gzip algorithm.

    The decompressed payload is read as a sequence of 4-byte little endian
    words: the first word is the number of cell values, followed by that
    many int32 cell values, followed by one word per dimension of the shape.

    Args:
        string_cells: A base64 string encoded cells state.

    Returns:
        The decoded cells state array, as float32.

    Raises:
        ValueError: If the decompressed payload is not a whole number of
            4-byte words.
    """
    gzip_bytes = codecs.decode(bytes(string_cells, 'utf-8'), 'base64')
    payload = gzip.decompress(gzip_bytes)

    max_val = NB_CHARS**2 - 1

    # The leading count is a size, so it is read unsigned.
    nb_int32 = int.from_bytes(payload[0:4], 'little')

    # Decode every word in one pass as *signed* little endian int32, which is
    # the dtype the values were serialized from.
    words = np.frombuffer(payload, dtype='<i4')
    values = words[1:1 + nb_int32]
    shape = words[1 + nb_int32:].tolist()

    cells_int32 = jnp.array(values, dtype=jnp.int32).reshape(shape)
    cells = jnp.array(cells_int32 / max_val, dtype=jnp.float32)

    return cells
def decompress_array_base64(string_cells: str) -> jnp.ndarray:
    """Decompress string encoded cells state using only the base64 algorithm.

    The string is base64-decoded and the resulting bytes are unpickled.

    Args:
        string_cells: A base64 string encoded cells state.

    Returns:
        The decoded cells state array, as float32.
    """
    pickled_bytes = codecs.decode(string_cells.encode(), "base64")
    # NOTE(security): unpickling can execute arbitrary code; this decoder must
    # only be fed strings coming from a trusted source.
    return jnp.array(pickle.loads(pickled_bytes), dtype=jnp.float32)
def ch2val(c: str) -> int:
    """Map a pair of characters to an integer.

    Each character is converted to an index and the pair is read as a
    two-digit number in base ``NB_CHARS``.

    Args:
        c: A string of exactly two characters.

    Returns:
        An integer.
    """
    assert len(c) == 2

    upper_span = ord('Z') - ord('A')
    lower_span = ord('z') - ord('a')

    def char_index(ch: str) -> int:
        code = ord(ch)
        if code >= ord('À'):
            # Third range, starting at 'À'.
            return code - ord('À') + upper_span + lower_span
        if code >= ord('a'):
            # Second range, starting at 'a'.
            return code - ord('a') + upper_span
        # First range, starting at 'A'.
        return code - ord('A')

    return char_index(c[0]) * NB_CHARS + char_index(c[1])
def val2ch(v: int) -> str:
    """Map an integer to a pair of characters.

    The integer is split into two base ``NB_CHARS`` digits and each digit is
    converted to one character.

    Args:
        v: An integer

    Returns:
        A string of two characters.
    """
    upper_span = ord('Z') - ord('A')
    lower_span = ord('z') - ord('a')

    def index_char(idx: int) -> str:
        # Three ranges are used ('A'.., 'a'.., 'À'..) so that the special
        # characters sitting between them are never produced.
        if ord('A') + idx <= ord('Z'):
            return chr(ord('A') + idx)
        if ord('a') + idx - upper_span <= ord('z'):
            return chr(ord('a') + idx - upper_span)
        return chr(ord('À') + idx - upper_span - lower_span)

    high_digit = v // NB_CHARS
    low_digit = v % NB_CHARS

    return index_char(high_digit) + index_char(low_digit)
def load_raw_cells(config: Dict, use_init_cells: bool = True) -> jnp.ndarray:
    """Load and decompress cells state contained in a Leniax configuration.

    The raw value is read from ``config['run_params']`` and converted
    according to its type:

    - the string ``'MISSING'`` gives an empty float32 array,
    - the string ``'last_frame.p'`` is unpickled from the file of that name
      inside ``config['main_path']``,
    - any other string is decoded with ``decompress_array``,
    - a list is converted to a float32 array,
    - any other value is used as is.

    Args:
        config: Leniax configuration
        use_init_cells: Set to ``True`` to use the ``init_cells`` configuration property.
            When that property is absent, ``cells`` is used instead.

    Returns:
        A Leniax cells state.
    """
    nb_dims = config['world_params']['nb_dims']
    run_params = config['run_params']

    # Backward compatibility: fall back on 'cells' when 'init_cells' is absent.
    if use_init_cells is True and 'init_cells' in run_params:
        cells = run_params['init_cells']
    else:
        cells = run_params['cells']

    if type(cells) is list:
        cells = jnp.array(cells, dtype=jnp.float32)
    elif type(cells) is str:
        if cells == 'MISSING':
            cells = jnp.array([], dtype=jnp.float32)
        elif cells == 'last_frame.p':
            with open(os.path.join(config['main_path'], 'last_frame.p'), 'rb') as f:
                cells = jnp.array(pickle.load(f), dtype=jnp.float32)
        else:
            # we add the channel dim
            cells = decompress_array(cells, nb_dims + 1)

    # We repair the missing channel in case the single channel got squeezed out
    channel_missing = len(cells.shape) == nb_dims and config['world_params']['nb_channels'] == 1
    if channel_missing:
        cells = jnp.expand_dims(cells, 0)

    return cells
###
# Deprecated, kept for backward compatibilities
###
# Delimiter joined between the encoded elements of each dimension,
# keyed by dimension counted from the innermost one (0).
DIM_DELIM = {0: '', 1: '$', 2: '%', 3: '#', 4: '@A', 5: '@B', 6: '@C', 7: '@D', 8: '@E', 9: '@F'}


def append_stack(list1: List, elem: Any, count, is_repeat=False):
    """Append ``elem`` to ``list1`` and pad it according to the run length ``count``.

    When ``count`` is not the empty string, ``int(count) - 1`` extra items
    are added: copies of ``elem`` if ``is_repeat`` is set, empty lists otherwise.
    """
    list1.append(elem)
    if count == '':
        return
    filler = elem if is_repeat else []
    list1.extend([filler] * (int(count) - 1))


def recur_get_max_lens(dim, list1, max_lens, nb_dims: int):
    """Record in ``max_lens`` the longest list found at each nesting depth."""
    if len(list1) > max_lens[dim]:
        max_lens[dim] = len(list1)
    if dim >= nb_dims - 1:
        return
    for sub_list in list1:
        recur_get_max_lens(dim + 1, sub_list, max_lens, nb_dims)


def recur_cubify(dim, list1, max_lens, nb_dims: int):
    """Pad the nested lists in place so every depth reaches its ``max_lens`` length.

    Inner levels are padded with empty lists (filled by the recursion), the
    innermost level is padded with zeros.
    """
    missing = max_lens[dim] - len(list1)
    if dim >= nb_dims - 1:
        list1.extend([0] * missing)
        return
    list1.extend([[]] * missing)
    for sub_list in list1:
        recur_cubify(dim + 1, sub_list, max_lens, nb_dims)


def _recur_drill_list(dim, lists, row_func, nb_dims: int):
    """Apply ``row_func`` to every innermost list, keeping the outer nesting."""
    if dim >= nb_dims - 1:
        return row_func(lists)
    return [_recur_drill_list(dim + 1, e, row_func, nb_dims) for e in lists]


def _recur_join_st(dim, lists, row_func, nb_dims: int):
    """Join nested lists into one string using the delimiter of each dimension."""
    delim = DIM_DELIM[nb_dims - 1 - dim]
    if dim < nb_dims - 1:
        pieces = (_recur_join_st(dim + 1, e, row_func, nb_dims) for e in lists)
    else:
        pieces = row_func(lists)
    return delim.join(pieces)


def char2val_deprecated(ch: str) -> int:
    """Map a one or two character code of the deprecated format to an integer."""
    if ch in '.b':
        return 0
    if ch == 'o':
        return 255
    if len(ch) == 1:
        return ord(ch) - ord('A') + 1
    return (ord(ch[0]) - ord('p')) * 24 + (ord(ch[1]) - ord('A') + 25)


def val2char_deprecated(v: int) -> str:
    """Map an integer to its one or two character code in the deprecated format."""
    if v == 0:
        return '.'
    if v < 25:
        return chr(ord('A') + v - 1)
    offset = v - 25
    return chr(ord('p') + offset // 24) + chr(ord('A') + offset % 24)


def deprecated_compress_array(cells):
    """Run-length encode a cells state into the deprecated text format.

    Values are scaled by 255 and rounded to integers before being encoded.
    The returned string ends with ``'!'``.
    """
    def drill_row(row):
        # One (run length, code) pair per run of equal consecutive values.
        return [(sum(1 for _ in run), val2char_deprecated(value).strip()) for value, run in itertools.groupby(row)]

    def join_row_shorten(row):
        # A run length of 1 is left implicit.
        return [(str(length) if length > 1 else '') + code for length, code in row]

    nb_dims = len(cells.shape)
    cells_int_l = np.rint(cells * 255).astype(int).tolist()  # [[255 255] [255 0]]

    rle_groups = _recur_drill_list(0, cells_int_l, drill_row, nb_dims)
    encoded = _recur_join_st(0, rle_groups, join_row_shorten, nb_dims)  # "2 yO $ 1 yO"

    return encoded + '!'


def deprecated_decompress_array(cells_code: str, nb_dims: int) -> jnp.ndarray:
    """Decode a string in the deprecated run-length format into a float32 array.

    Decoded values are divided by 255. Rows of unequal length are padded
    with zeros so the result is rectangular.
    """
    stacks: List[List] = [[] for _ in range(nb_dims)]
    delims = list(DIM_DELIM.values())
    prefix, run_length = '', ''

    # A closing outermost delimiter is appended so every pending stack is flushed.
    encoded = cells_code.rstrip('!') + DIM_DELIM[nb_dims - 1]
    for ch in encoded:
        if ch.isdigit():
            run_length += ch
            continue
        if ch in 'pqrstuvwxy@':
            # First character of a two character token.
            prefix = ch
            continue

        token = prefix + ch
        if token in delims:
            # Close every dimension below the one this delimiter belongs to.
            for d in range(delims.index(token)):
                append_stack(stacks[d + 1], stacks[d], run_length, is_repeat=False)
                stacks[d] = []
        else:
            append_stack(stacks[0], char2val_deprecated(token) / 255, run_length, is_repeat=True)
        prefix, run_length = '', ''

    nested = stacks[nb_dims - 1]
    max_lens = [0] * nb_dims
    recur_get_max_lens(0, nested, max_lens, nb_dims)
    recur_cubify(0, nested, max_lens, nb_dims)

    return jnp.array(nested, dtype=jnp.float32)