Source code for fortuna.prob_model.posterior.sgmcmc.sgmcmc_step_schedule

from typing import Callable

import jax.numpy as jnp
import numpy as np

from fortuna.typing import Array

# Type of a step-size schedule: a callable that receives the current step as an
# ``Array`` and returns the step size to use at that step. The schedule
# factories in this module are annotated as returning this type.
StepSchedule = Callable[[Array], Array]


def constant_schedule(init_step_size: float) -> StepSchedule:
    """Build a schedule that always yields the same step size.

    Parameters
    ----------
    init_step_size: float
        The step size returned at every step. Must be non-negative.

    Returns
    -------
    schedule_fn: StepSchedule
        A callable that ignores its argument and returns ``init_step_size``.

    Raises
    ------
    ValueError
        If ``init_step_size >= 0`` does not hold.
    """
    is_non_negative = init_step_size >= 0
    if not is_non_negative:
        raise ValueError("`init_step_size` should be >= 0.")

    def _constant(_step: Array):
        # The step counter is accepted only to satisfy the schedule signature.
        return init_step_size

    return _constant
def cosine_schedule(init_step_size: float, total_steps: int) -> StepSchedule:
    """Build a cosine-shaped step schedule.

    At step ``s`` the schedule returns
    ``0.5 * init_step_size * (1 + cos(pi * s / total_steps))``.

    Parameters
    ----------
    init_step_size: float
        The initial step size (the value returned at step 0). Must be
        non-negative.
    total_steps: int
        The cycle length, in steps. Must be strictly positive.

    Returns
    -------
    schedule_fn: StepSchedule
        A callable mapping a step to its step size.

    Raises
    ------
    ValueError
        If ``init_step_size >= 0`` or ``total_steps > 0`` does not hold.
    """
    if not (init_step_size >= 0):
        raise ValueError("`init_step_size` should be >= 0.")
    if not (total_steps > 0):
        raise ValueError("`total_steps` should be > 0.")

    def _cosine(step: Array):
        # Fraction of the cycle that has elapsed at this step.
        progress = step / total_steps
        cosine_term = 1 + jnp.cos(progress * np.pi)
        return 0.5 * init_step_size * cosine_term

    return _cosine
def polynomial_schedule(
    a: float = 1.0, b: float = 1.0, gamma: float = 0.55
) -> StepSchedule:
    r"""Create a polynomial step schedule.

    At step ``s`` the schedule returns ``a * (b + s) ** (-gamma)``.

    Parameters
    ----------
    a: float
        Scale of all step sizes.
    b: float
        The stabilization constant.
    gamma: float
        The decay rate :math:`\gamma \in (0.5, 1.0]`.

    Returns
    -------
    schedule_fn: StepSchedule
        A callable mapping a step to its step size.

    Raises
    ------
    ValueError
        If ``gamma`` is not in the ``(0.5, 1.0]`` range.
    """
    # NOTE: the docstring is a raw string because it contains LaTeX
    # backslashes (``\gamma``, ``\in``), which are invalid escape sequences
    # in a regular string literal.
    if not 0.5 < gamma <= 1.0:
        raise ValueError("`gamma` should be in (0.5, 1.0] range.")

    def schedule(step: Array):
        return a * (b + step) ** (-gamma)

    return schedule
def constant_schedule_with_cosine_burnin(
    init_step_size: float, final_step_size: float, burnin_steps: int
) -> StepSchedule:
    """Create a constant schedule with cosine burn-in.

    During the first ``burnin_steps`` steps the step size moves from
    ``init_step_size`` to ``final_step_size`` along half a cosine period;
    afterwards it stays at ``final_step_size``.

    Parameters
    ----------
    init_step_size: float
        The initial step size.
    final_step_size: float
        The desired final step size.
    burnin_steps: int
        The length of burn-in, in steps. With ``burnin_steps == 0`` there is
        no burn-in and the schedule returns ``final_step_size`` at every step.

    Returns
    -------
    schedule_fn: StepSchedule
        A callable mapping a step to its step size.

    Raises
    ------
    ValueError
        If any of the three arguments fails its ``>= 0`` check.
    """
    if not init_step_size >= 0:
        raise ValueError("`init_step_size` should be >= 0.")
    if not final_step_size >= 0:
        raise ValueError("`final_step_size` should be >= 0.")
    if not burnin_steps >= 0:
        raise ValueError("`burnin_steps` should be >= 0.")

    def schedule(step: Array):
        if burnin_steps > 0:
            # Burn-in progress, saturating at 1 once burn-in is over.
            t = jnp.minimum(step / burnin_steps, 1.0)
        else:
            # Zero-length burn-in is accepted by the validation above, but
            # `step / 0` would raise or yield NaN. Treat burn-in as already
            # finished instead, keeping the shape of `step`.
            t = jnp.ones(jnp.shape(step))
        # coef goes 1 -> 0 as t goes 0 -> 1, blending init into final.
        coef = (1 + jnp.cos(t * np.pi)) * 0.5
        return coef * init_step_size + (1 - coef) * final_step_size

    return schedule
def cyclical_cosine_schedule_with_const_burnin(
    init_step_size: float, burnin_steps: int, cycle_length: int
) -> StepSchedule:
    """Create a cyclical cosine schedule with constant burn-in.

    While ``step - burnin_steps - 1 <= 0`` the schedule returns
    ``init_step_size``. After that, with ``t = step - burnin_steps - 1``, it
    returns
    ``0.5 * init_step_size * (1 + cos(pi * (t % cycle_length) / cycle_length))``,
    restarting from ``init_step_size`` every ``cycle_length`` steps.

    Parameters
    ----------
    init_step_size: float
        The initial step size.
    burnin_steps: int
        The length of burn-in, in steps.
    cycle_length: int
        The length of the cosine cycle, in steps. Must be strictly positive.

    Returns
    -------
    schedule_fn: StepSchedule
        A callable mapping a step to its step size.

    Raises
    ------
    ValueError
        If ``init_step_size >= 0``, ``burnin_steps >= 0`` or
        ``cycle_length > 0`` does not hold.
    """
    if not init_step_size >= 0:
        raise ValueError("`init_step_size` should be >= 0.")
    if not burnin_steps >= 0:
        raise ValueError("`burnin_steps` should be >= 0.")
    # A zero-length cycle would make the modulo and division below operate on
    # zero and produce NaN step sizes, so it is rejected up front.
    if not cycle_length > 0:
        raise ValueError("`cycle_length` should be > 0.")

    def schedule(step: Array):
        # Steps elapsed since the end of burn-in, clamped at zero during it.
        t = jnp.maximum(step - burnin_steps - 1, 0.0)
        # Position inside the current cycle, in [0, 1).
        t = (t % cycle_length) / cycle_length
        return 0.5 * init_step_size * (1 + jnp.cos(t * np.pi))

    return schedule