Source code for pomdp_belief_tracking.pf.importance_sampling

"""Provides a general implementation of importance sampling:

.. autofunction:: general_importance_sample
   :noindex:

Our belief update version of importance sampling calls this function with the
appropriate parameters:

.. autofunction:: importance_sample
   :noindex:

Which can best be created through our construction function (with sane
defaults, otherwise you can also apply partial):

.. autofunction:: create_importance_sampling
   :noindex:

Sequential importance sampling, the application of the belief update over
multiple time steps, often involves :func:`resample` to avoid particle
degeneration. When to resample is not straightforward; we provide a general
condition protocol

.. autoclass:: ResampleCondition
   :noindex:

Lastly, we provide a factory function that combines importance sampling with
resampling

.. autofunction:: create_sequential_importance_sampling
   :noindex:

"""

from __future__ import annotations

from copy import deepcopy
from functools import partial
from timeit import default_timer as timer
from typing import Any, Callable, Iterable, List, Optional, Tuple

from typing_extensions import Protocol

from pomdp_belief_tracking.pf.particle_filter import Particle, ParticleFilter
from pomdp_belief_tracking.pf.types import ProposalDistribution
from pomdp_belief_tracking.types import (
    Action,
    BeliefUpdate,
    Info,
    Observation,
    State,
    StateDistribution,
    TransitionFunction,
)


class WeightFunction(Protocol):
    """Callable protocol for the weighting step of :func:`general_importance_sample`

    Any callable accepting ``(proposal, sample_ctx, info)`` and returning a
    ``float`` satisfies this protocol.

    .. automethod:: __call__

    """

    def __call__(self, proposal: State, sample_ctx: Any, info: Info) -> float:
        """Computes the weight of ``proposal`` given its ``sample_ctx``

        :param proposal: proposed (updated) sample
        :param sample_ctx: context around proposal, as produced by the proposal distribution
        :param info: global information stored during importance sampling
        :return: a 0 <= weight <= 1
        """
        ...
def general_importance_sample(
    proposal_distr: ProposalDistribution,
    weight_func: WeightFunction,
    particles: Iterable[Particle],
) -> Tuple[ParticleFilter, Info]:
    """Runs one step of importance sampling over a set of particles

    Every incoming particle is unpacked into its state and weight, after
    which the state is replaced by a proposal and the weight is rescaled::

        for state, weight in particles:
            proposal, ctx = proposal_distr(state, info)
            weight <- weight * weight_func(proposal, ctx, info)

    The same ``info`` dictionary is handed to every call of ``proposal_distr``
    and ``weight_func``. The time spent in the loop is stored in ``info``
    under the key "belief_update_runtime".

    :param proposal_distr: function to propose sample updates
    :param weight_func: function that weights proposals
    :param particles: the starting set of particles
    :return: the new particle filter and the ``info`` dictionary
    """
    info: Info = {}
    updated: List[Particle] = []

    start = timer()
    for state, weight in particles:
        proposal, ctx = proposal_distr(state, info)
        updated.append(Particle(proposal, weight * weight_func(proposal, ctx, info)))
    # only the propose-and-weight loop is timed, not building the filter
    info["belief_update_runtime"] = timer() - start

    return ParticleFilter.from_particles(updated), info
def resample(pf: ParticleFilter, n: int) -> ParticleFilter:
    """Builds a new particle filter out of ``n`` samples drawn from ``pf``

    Each sample is obtained by calling ``pf()`` and is deep-copied before
    being stored, so the returned filter does not share sample objects with
    ``pf``.

    .. todo: This implementation is squared in the number of particles
        because sampling from particle filter is linear in number of
        particles. Good excuse for optimization, maybe allow for sampling
        multiple at a time)

    :param pf: incoming particle filter
    :param n: number of desired samples in returned PF (asserted > 0)
    :return: the resulting particle filter of resampling ``pf``
    """
    assert n > 0

    samples = [deepcopy(pf()) for _ in range(n)]
    return ParticleFilter(samples)
def importance_sample(
    transition_func: TransitionFunction,
    observation_model: Callable[[State, Action, State, Observation], float],
    n: Optional[int],
    initial_state_distribution: StateDistribution,
    a: Action,
    o: Observation,
) -> Tuple[ParticleFilter, Info]:
    """Applies :func:`general_importance_sample` on POMDPs

    Next states are proposed by ``transition_func`` and weighted by
    ``observation_model`` given the observation ``o``.

    If ``initial_state_distribution`` is _not_ a
    :class:`~pomdp_belief_tracking.pf.particle_filter.ParticleFilter`, it is
    first converted into one through ``ParticleFilter.from_distribution``
    using ``n`` samples; ``n`` must then be a positive number. Otherwise
    ``n`` is ignored and the particles of the filter are used directly.

    :param transition_func: the proposal function
    :param observation_model: the model to weight the probability of generating observation ``o``
    :param n: num samples, optional
    :param initial_state_distribution: the starting distribution
    :param a: taken action
    :param o: taken observation
    :return: updated belief and the ``info`` of the update
    """
    # make sure we start importance sampling from a particle filter
    if not isinstance(initial_state_distribution, ParticleFilter):
        assert n and n > 0
        initial_state_distribution = ParticleFilter.from_distribution(
            initial_state_distribution, n
        )

    def propose(s: State, info: Info) -> Tuple[State, Any]:
        """proposes a next state for ``s`` with the transition function"""
        return transition_func(s, a), {"action": a, "state": s, "observation": o}

    def weigh(proposal: State, sample_ctx: Any, info: Info) -> float:
        """weights ``proposal`` by the observation model on the stored context"""
        return observation_model(
            sample_ctx["state"],
            sample_ctx["action"],
            proposal,
            sample_ctx["observation"],
        )

    return general_importance_sample(
        propose, weigh, iter(initial_state_distribution.particles)
    )
def create_importance_sampling(
    transition_func: TransitionFunction,
    observation_model: Callable[[State, Action, State, Observation], float],
    n: Optional[int] = None,
) -> BeliefUpdate:
    """Partial function that returns a regular IS belief update

    A simple wrapper around
    :func:`~pomdp_belief_tracking.pf.importance_sampling.importance_sample`
    that binds its first three arguments, leaving a callable that takes the
    initial state distribution, the action and the observation.

    Here the ``transition_func`` is used to propose next states, which are
    weighted according to the ``observation_model``.

    ``n`` is necessary when the distribution given to the belief update is
    not a :class:`~pomdp_belief_tracking.pf.particle_filter.ParticleFilter`.
    Otherwise it is ignored, which is why it defaults to ``None`` (the same
    default as :func:`create_sequential_importance_sampling`).

    :param transition_func: how to update states
    :param observation_model: how to weight transitions
    :param n: num samples, optional
    :return: func:`importance_sample` as :class:`pomdp_belief_tracking.types.BeliefUpdate`
    """
    return partial(importance_sample, transition_func, observation_model, n)
class ResampleCondition(Protocol):
    """Callable protocol deciding whether a particle filter needs resampling

    .. automethod:: __call__

    Provided implementations:

    .. autosummary::
       :nosignatures:

       ineffective_sample_size

    """

    def __call__(self, pf: ParticleFilter) -> bool:
        """Looks at ``pf`` and tells whether it should be re-sampled

        :param pf: the particle filter to potentially resample
        :return: ``True`` if ``pf`` should be resampled
        """
        ...
def ineffective_sample_size(minimal_size: float, pf: ParticleFilter):
    """Tells whether the effective sample size of ``pf`` is below ``minimal_size``

    Binding ``minimal_size`` (e.g. with :func:`functools.partial`) results in
    a callable that implements the :class:`ResampleCondition` protocol.

    Asserts that ``minimal_size`` > 0.

    The sample size is obtained from ``pf.effective_sample_size()``.

    :param minimal_size: the required sample size for this to return False (> 0)
    :param pf: the particle filter to test the sample size of
    :returns: True if ``minimal_size`` > sample size of ``pf``
    """
    assert minimal_size > 0, f"effective sample size ({minimal_size}) must be positive"

    current_size = pf.effective_sample_size()
    return minimal_size > current_size
def create_sequential_importance_sampling(
    resample_condition: ResampleCondition,
    transition_func: TransitionFunction,
    observation_model: Callable[[State, Action, State, Observation], float],
    n: Optional[int] = None,
) -> BeliefUpdate:
    """Main entry point of this module to create importance sampling update

    Combines :func:`resample` (applied only if ``resample_condition`` is met)
    with :func:`importance_sample` (created by calling
    :func:`create_importance_sampling`).

    ``n`` is necessary when ``initial_state_distribution`` is not a
    :class:`~pomdp_belief_tracking.pf.particle_filter.ParticleFilter`.
    Otherwise ignored.

    :param resample_condition: when to resample (called before IS)
    :param transition_func: the transition function to propose particles
    :param observation_model: the function to weight the new particles
    :param n: number of desired particles
    :return: a belief update that resamples when needed before applying IS
    """
    importance_sampling_update = create_importance_sampling(
        transition_func, observation_model, n
    )

    def belief_update(
        p: StateDistribution, a: Action, o: Observation
    ) -> Tuple[StateDistribution, Info]:
        """Resamples ``p`` if the condition asks for it, then applies IS

        Whether resampling took place is stored in the returned info under
        the key "importance_sampling_resampled".

        :param p: the current belief
        :param a: taken action
        :param o: perceived observation
        :return: the updated belief and info of the update
        """
        # only particle filters are resampled; the size of the filter is kept
        needs_resampling = bool(
            isinstance(p, ParticleFilter) and resample_condition(p)
        )
        if needs_resampling:
            p = resample(p, len(p))

        belief, info = importance_sampling_update(p, a, o)
        info["importance_sampling_resampled"] = needs_resampling

        return belief, info

    return belief_update