Source code for fortuna.utils.random

from jax import random
from jax._src.prng import PRNGKeyArray
from jax.tree_util import (
    tree_map,
    tree_structure,
    tree_unflatten,
)
from optax._src.base import PyTree


def generate_rng_like_tree(rng, target: PyTree):
    treedef = tree_structure(target)
    keys = random.split(rng, treedef.num_leaves)
    return tree_unflatten(treedef, keys)


def generate_random_normal_like_tree(rng, target: PyTree):
    keys = generate_rng_like_tree(rng, target)
    return tree_map(
        lambda l, k: random.normal(k, l.shape, l.dtype),
        target,
        keys,
    )


[docs]class RandomNumberGenerator: def __init__(self, seed: int): """ A random number generator object. Parameters ---------- seed : int A random seed. """ self._rng = random.PRNGKey(seed)
[docs] def get(self) -> PRNGKeyArray: """ Get the internal random number generator key. Whenever this function is called, the random number generator key is updated. Returns ------- PRNGKeyArray A random number generator key. """ self._rng = random.split(self._rng)[0] return self._rng
class WithRNG: @property def rng(self) -> RandomNumberGenerator: """ Invoke the random number generator object. Returns ------- The random number generator object. """ return self._rng @rng.setter def rng(self, rng: RandomNumberGenerator): """ Set a random number generator object. Parameters ---------- rng : RandomNumberGenerator A random number generator object. """ self._rng = rng