Usage

Here is an example of how to track the belief in the canonical Tiger problem. First, we define the environment:

class Tiger:
    """The Tiger POMDP environment.

    States and observations share one encoding: ``L`` (left) and ``R``
    (right). ``H`` is the listening ("hear") action. Any other valid action
    (``L`` or ``R``) is a door-opening action, which resets the problem.
    """

    # States / observations / door-opening actions
    L = 0
    R = 1
    # Listening action
    H = 2

    # Reward constants (not used by the dynamics defined in this class)
    H_REWARD = -1
    OPEN_CORRECT_REWARD = 10
    OPEN_INCORRECT_REWARD = -100

    @staticmethod
    def sample_observation(s: State) -> Observation:
        """Sample what is heard when listening while the tiger is in ``s``.

        With probability 0.85 the tiger is heard correctly (the observation
        equals ``s``); otherwise the opposite side is returned. Keep this
        probability in sync with ``observation_model``.

        :param s: current state (``Tiger.L`` or ``Tiger.R``)
        :return: the sampled observation
        """
        if random.uniform(0, 1) < 0.85:
            return s
        # Flips L (0) <-> R (1)
        return int(not s)

    @staticmethod
    def sim(s: State, a: Action) -> Tuple[State, Observation]:
        """Simulate one step of the tiger dynamics.

        :param s: current state (``Tiger.L`` or ``Tiger.R``)
        :param a: action (``Tiger.H``, ``Tiger.L`` or ``Tiger.R``)
        :return: ``(next_s, o)``
        :raises ValueError: if ``a`` is not a valid action
        """

        if a == Tiger.H:
            # Listening leaves the tiger where it is; only the observation is noisy
            o = Tiger.sample_observation(s)
            return s, o

        # Explicit check rather than ``assert`` so that invalid actions are
        # still rejected when Python runs with -O (asserts stripped).
        if a not in (Tiger.L, Tiger.R):
            raise ValueError(f"Unknown action {a}")

        # Opening a door resets the problem: the observation is uninformative
        # and the tiger is placed uniformly at random.
        o = random.choice([Tiger.L, Tiger.R])
        s = random.choice([Tiger.L, Tiger.R])

        return s, o

    @staticmethod
    def observation_model(a: Action, next_s: State) -> List[float]:
        """Return the observation probabilities for an ``(a, next_s)`` pair.

        :param a: taken action
        :param next_s: next state
        :return: [prob hearing left, prob hearing right], so it can be indexed
            directly by the observation (``Tiger.L`` == 0, ``Tiger.R`` == 1)
        :raises ValueError: if ``a`` is ``Tiger.H`` and ``next_s`` is not a
            valid state
        """
        if a != Tiger.H:
            # Door-opening actions produce uniformly random observations (see sim)
            return [0.5, 0.5]

        # 0.85 must match the hearing accuracy used in sample_observation
        if next_s == Tiger.L:
            return [0.85, 0.15]

        # Explicit check rather than ``assert`` so it survives python -O
        if next_s != Tiger.R:
            raise ValueError(f"How did {next_s} become a state?")

        return [0.15, 0.85]

Then, given a belief (represented here as a function that returns a sampled state):

def uniform_tiger_belief():
    """Draw one state from the uniform belief over the tiger's location.

    Every call returns ``Tiger.L`` or ``Tiger.R`` with equal probability.
    """
    sides = [Tiger.L, Tiger.R]
    return random.choice(sides)

We can update the belief with, for example, rejection sampling:

from pomdp_belief_tracking.pf.rejection_sampling import (
    ParticleFilter,
    accept_noop,
    create_rejection_sampling,
    reject_noop,
)

# Build a rejection-sampling belief update driven by the Tiger simulator.
# NOTE(review): 100 is presumably the number of samples/particles to produce;
# confirm against create_rejection_sampling's signature.
# accept_noop / reject_noop are passed as the hooks for accepted / rejected
# samples; judging by their names they do nothing. TODO confirm.
belief_update = create_rejection_sampling(
    Tiger.sim,
    100,
    process_acpt=accept_noop,
    process_rej=reject_noop,
)

# Update the uniform belief after listening (Tiger.H) and observing Tiger.L;
# presumably the arguments are (belief, action, observation).
b, run_time_info = belief_update(uniform_tiger_belief, Tiger.H, Tiger.L)

Or, alternatively, with importance sampling:

from pomdp_belief_tracking.pf.importance_sampling import create_importance_sampling

# Sample count passed to create_importance_sampling below
# (presumably the number of particles; confirm).
n = 100

def trans_func(s, a):
    """Sample a next state for ``(s, a)`` via ``Tiger.sim``, discarding the observation."""
    next_state, _observation = Tiger.sim(s, a)
    return next_state

def obs_func(s, a, ss, o):
    """Return the probability of observation ``o`` after action ``a`` led from ``s`` to ``ss``.

    ``s`` is not used: the Tiger observation model depends only on the action
    and the next state. ``o`` indexes the [left, right] probability list.
    """
    probabilities = Tiger.observation_model(a, ss)
    return probabilities[o]

# Build an importance-sampling belief update from the transition sampler
# (trans_func) and the observation likelihood (obs_func), using n samples.
# NOTE(review): argument roles inferred from names; confirm in
# create_importance_sampling.
belief_update = create_importance_sampling(trans_func, obs_func, n)

# Same update as in the rejection-sampling example: listen (Tiger.H), observe Tiger.L.
b, run_time_info = belief_update(uniform_tiger_belief, Tiger.H, Tiger.L)