Skip to content

bordax.algorithms

bordax.algorithms.base

Algorithm

Bases: NamedTuple

A training algorithm composed of a collector, batch builder, and updater.

Attributes:

Name Type Description
collector Collector

Generates transitions by interacting with the environment.

batch_builder BatchBuilder

Transforms collected data into training batches.

updater Updater

Applies gradient updates to the network parameters.

Source code in bordax/algorithms/base.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class Algorithm(NamedTuple):
    """Bundle of the three stages that make up one training algorithm.

    Attributes:
        collector: Callable invoked by ``collect`` to gather experience from
            the environment.
        batch_builder: Callable invoked by ``train_step`` to turn the replay
            buffer into a training batch.
        updater: Callable invoked by ``update`` to apply the parameter update;
            its ``init`` method builds the initial training state.
    """

    collector: Collector
    batch_builder: BatchBuilder
    updater: Updater

    def init_training_state(
        self, agent: Agent, key: PRNGKey, sample_obs: Any, env: EnvAdapter
    ) -> TrainingState:
        """Build the initial training state for ``agent``.

        Args:
            agent: Agent whose ``init`` produces the initial parameters.
            key: JAX random key passed to ``agent.init``.
            sample_obs: Sample observation passed to ``agent.init``.
            env: The training environment. Not read by this method; it is
                part of the signature only.

        Returns:
            The ``TrainingState`` produced by ``self.updater.init`` from the
            freshly initialised parameters.
        """
        return self.updater.init(agent.init(key, sample_obs))

    def collect(
        self,
        key: PRNGKey,
        env: EnvAdapter,
        obs: EnvObs,
        env_state: EnvState,
        replay_buffer: Any,
        agent: Agent,
        ts: TrainingState,
    ):
        """Gather experience by delegating to ``self.collector``.

        All arguments are forwarded positionally, in the order
        ``(key, env, obs, env_state, replay_buffer, agent, ts)``.

        Returns:
            Whatever the collector returns; ``train_step`` unpacks it as
            ``((obs, env_state), replay_buffer)``.
        """
        collected = self.collector(
            key, env, obs, env_state, replay_buffer, agent, ts
        )
        return collected

    @functools.partial(jax.jit, static_argnames=("self", "agent"))
    def update(self, agent: Agent, batch: Any, ts: TrainingState, key: PRNGKey):
        """Run one parameter update under ``jax.jit``.

        ``self`` and ``agent`` are declared static for the JIT, so they are
        treated as compile-time constants rather than traced values.

        Args:
            agent: Agent forwarded to the updater.
            batch: Training batch produced by the batch builder.
            ts: Current training state.
            key: JAX random key.

        Returns:
            The updater's result; ``train_step`` unpacks it as
            ``(new_training_state, metrics)``.
        """
        return self.updater(agent, batch, ts, key)

    def train_step(
        self,
        env: EnvAdapter,
        agent: Agent,
        key: PRNGKey,
        ts: TrainingState,
        replay_buffer: Any,
        obs: EnvObs,
        env_state: EnvState,
    ):
        """Perform one training iteration: collect, build a batch, update.

        The incoming key is split four ways: one carry key that is handed
        back to the caller and one sub-key for each of the three stages.

        Returns:
            Tuple of ``((key, ts, replay_buffer, obs, env_state), metrics)``
            where ``key`` is the carry key for the next iteration.
        """
        next_key, k_collect, k_batch, k_update = jax.random.split(key, 4)

        # The collector hands back the buffer to train on together with the
        # environment carry (latest observation and environment state).
        env_carry, replay_buffer = self.collect(
            k_collect, env, obs, env_state, replay_buffer, agent, ts
        )
        obs, env_state = env_carry

        batch = self.batch_builder(k_batch, replay_buffer)
        ts, metrics = self.update(agent, batch, ts, k_update)

        carry = (next_key, ts, replay_buffer, obs, env_state)
        return carry, metrics

collect(key, env, obs, env_state, replay_buffer, agent, ts)

Collect experience from the environment.

Delegates to self.collector. For on-policy algorithms the returned buffer contains the freshly collected rollout; for off-policy algorithms transitions are appended to the existing replay buffer which is returned.

Returns:

Type Description

Tuple of ((obs, env_state), replay_buffer).

Source code in bordax/algorithms/base.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def collect(
    self,
    key: PRNGKey,
    env: EnvAdapter,
    obs: EnvObs,
    env_state: EnvState,
    replay_buffer: Any,
    agent: Agent,
    ts: TrainingState,
):
    """Gather experience by delegating to ``self.collector``.

    All arguments are forwarded positionally, in the order
    ``(key, env, obs, env_state, replay_buffer, agent, ts)``.

    Returns:
        Whatever the collector returns, documented as a tuple of
        ``((obs, env_state), replay_buffer)``.
    """
    collected = self.collector(
        key, env, obs, env_state, replay_buffer, agent, ts
    )
    return collected

init_training_state(agent, key, sample_obs, env)

Initialise the training state for a given agent.

Parameters:

Name Type Description Default
agent Agent

The agent whose parameters are initialised.

required
key PRNGKey

JAX random key.

required
sample_obs Any

A sample observation used to infer network input shapes.

required
env EnvAdapter

The training environment (used by some updaters).

required

Returns:

Type Description
TrainingState

A TrainingState containing initial parameters and optimizer state.

Source code in bordax/algorithms/base.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def init_training_state(
    self, agent: Agent, key: PRNGKey, sample_obs: Any, env: EnvAdapter
) -> TrainingState:
    """Build the initial training state for ``agent``.

    Args:
        agent: Agent whose ``init`` produces the initial parameters.
        key: JAX random key passed to ``agent.init``.
        sample_obs: Sample observation passed to ``agent.init``.
        env: The training environment. Not read by this method; it is part
            of the signature only.

    Returns:
        The ``TrainingState`` produced by ``self.updater.init`` from the
        freshly initialised parameters.
    """
    return self.updater.init(agent.init(key, sample_obs))

train_step(env, agent, key, ts, replay_buffer, obs, env_state)

Run one full training iteration: collect → batch → update.

This method is JIT-compiled by the Trainer when the environment is jittable and the algorithm is on-policy.

Returns:

Type Description

Tuple of ((key, ts, replay_buffer, obs, env_state), metrics).

Source code in bordax/algorithms/base.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def train_step(
    self,
    env: EnvAdapter,
    agent: Agent,
    key: PRNGKey,
    ts: TrainingState,
    replay_buffer: Any,
    obs: EnvObs,
    env_state: EnvState,
):
    """Perform one training iteration: collect, build a batch, update.

    The incoming key is split four ways: one carry key that is handed back
    to the caller and one sub-key for each of the three stages.

    Returns:
        Tuple of ``((key, ts, replay_buffer, obs, env_state), metrics)``
        where ``key`` is the carry key for the next iteration.
    """
    next_key, k_collect, k_batch, k_update = jax.random.split(key, 4)

    # The collector hands back the buffer to train on together with the
    # environment carry (latest observation and environment state).
    env_carry, replay_buffer = self.collect(
        k_collect, env, obs, env_state, replay_buffer, agent, ts
    )
    obs, env_state = env_carry

    batch = self.batch_builder(k_batch, replay_buffer)
    ts, metrics = self.update(agent, batch, ts, k_update)

    carry = (next_key, ts, replay_buffer, obs, env_state)
    return carry, metrics

update(agent, batch, ts, key)

JIT-compiled parameter update step.

Parameters:

Name Type Description Default
agent Agent

Agent providing loss function access.

required
batch Any

Training batch produced by the batch builder.

required
ts TrainingState

Current training state.

required
key PRNGKey

JAX random key.

required

Returns:

Type Description

Tuple of (new_training_state, metrics_dict).

Source code in bordax/algorithms/base.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@functools.partial(jax.jit, static_argnames=("self", "agent"))
def update(self, agent: Agent, batch: Any, ts: TrainingState, key: PRNGKey):
    """Run one parameter update under ``jax.jit``.

    ``self`` and ``agent`` are declared static for the JIT, so they are
    treated as compile-time constants rather than traced values.

    Args:
        agent: Agent forwarded to the updater.
        batch: Training batch produced by the batch builder.
        ts: Current training state.
        key: JAX random key.

    Returns:
        The updater's result, documented as a tuple of
        ``(new_training_state, metrics_dict)``.
    """
    return self.updater(agent, batch, ts, key)

dqn_algo(epsilon_schedule=lambda t: 0.1, rollout_length=1, batch_size=32, gamma=0.99, lr=0.0001, target_update_freq=1000, applied_loss=optax.squared_error, **kwargs)

Create a DQN algorithm.

Parameters:

Name Type Description Default
epsilon_schedule Callable[[int], float]

Callable (step) -> epsilon controlling the exploration rate over time.

lambda t: 0.1
rollout_length int

Number of environment steps collected per update. Typically 1 for standard DQN.

1
batch_size int

Number of transitions sampled from the replay buffer per update.

32
gamma float

Discount factor for Bellman targets.

0.99
lr float

Adam learning rate for the Q-network.

0.0001
target_update_freq int

Number of training steps between target network hard updates.

1000
applied_loss Callable

Element-wise loss applied to TD errors (e.g. optax.squared_error or optax.huber_loss).

squared_error

Returns:

Type Description

A configured Algorithm for DQN.

Source code in bordax/algorithms/base.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def dqn_algo(
    epsilon_schedule: Callable[[int], float] = lambda t: 0.1,
    rollout_length: int = 1,
    batch_size: int = 32,
    gamma: float = 0.99,
    lr: float = 1e-4,
    target_update_freq: int = 1000,
    applied_loss: Callable = optax.squared_error,
    **kwargs
):
    """Assemble the collector, batch builder and updater for DQN.

    Args:
        epsilon_schedule: Callable ``(step) -> epsilon`` handed to the
            epsilon-greedy collector.
        rollout_length: Number of environment steps collected per update,
            handed to the collector.
        batch_size: Batch size handed to ``UniformReplayBatch``.
        gamma: Discount factor handed to ``DQNLoss``.
        lr: Learning rate of the Adam optimizer.
        target_update_freq: Target-update frequency handed to
            ``DQNUpdater``.
        applied_loss: Loss callable handed to ``DQNLoss`` (e.g.
            ``optax.squared_error`` or ``optax.huber_loss``).
        **kwargs: Accepted and ignored, so extra configuration keys do not
            raise.

    Returns:
        A configured ``Algorithm`` for DQN.
    """
    collector = EpsGreedyCollector(
        epsilon_schedule=epsilon_schedule,
        rollout_length=rollout_length,
    )
    batch_builder = UniformReplayBatch(batch_size)

    adam = optax.adam(lr)
    td_loss = DQNLoss(gamma=gamma, applied_loss=applied_loss)
    updater = DQNUpdater(
        optimizer=adam,
        loss_fn=td_loss,
        target_update_freq=target_update_freq,
    )

    return Algorithm(collector, batch_builder, updater)

ppo_algo(rollout_length=1024, gamma=0.99, _lambda=0.85, lr=0.001, clip_schedule=lambda _: 0.2, vf_schedule=lambda _: 0.5, ent_schedule=lambda _: 0.01, num_minibatches=16, num_sgd_steps=1, num_envs=1, grad_clip=0.5, **kwargs)

Create a PPO algorithm.

Parameters:

Name Type Description Default
rollout_length int

Number of environment steps collected per epoch per environment. Must be divisible by num_minibatches.

1024
gamma float

Discount factor for returns.

0.99
_lambda float

GAE lambda for advantage estimation.

0.85
lr float

Adam learning rate.

0.001
clip_schedule

Callable (step) -> clip_ratio. Defaults to constant 0.2.

lambda _: 0.2
vf_schedule

Callable (step) -> vf_coef. Defaults to 0.5.

lambda _: 0.5
ent_schedule

Callable (step) -> ent_coef. Defaults to 0.01.

lambda _: 0.01
num_minibatches

Number of minibatches to split each rollout into.

16
num_sgd_steps

Number of SGD passes over the data per epoch.

1
num_envs int

Number of parallel environments (used for batch reshaping).

1
grad_clip float

Global gradient norm clipping threshold.

0.5

Returns:

Type Description

A configured Algorithm for PPO.

Source code in bordax/algorithms/base.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def ppo_algo(
    rollout_length: int = 1024,
    gamma: float = 0.99,
    _lambda: float = 0.85,
    lr: float = 0.001,
    clip_schedule: Callable[[int], float] = lambda _: 0.2,
    vf_schedule: Callable[[int], float] = lambda _: 0.5,
    ent_schedule: Callable[[int], float] = lambda _: 0.01,
    num_minibatches: int = 16,
    num_sgd_steps: int = 1,
    num_envs: int = 1,
    grad_clip: float = 0.5,
    **kwargs
):
    """Create a PPO algorithm.

    Args:
        rollout_length: Number of environment steps collected per epoch
            per environment. Must be divisible by ``num_minibatches``.
        gamma: Discount factor, handed to the on-policy collector.
        _lambda: GAE lambda, handed to the on-policy collector.
        lr: Adam learning rate (wrapped in a constant schedule).
        clip_schedule: Callable ``(step) -> clip_ratio``. Defaults to
            constant 0.2.
        vf_schedule: Callable ``(step) -> vf_coef``. Defaults to 0.5.
        ent_schedule: Callable ``(step) -> ent_coef``. Defaults to 0.01.
        num_minibatches: Number of minibatches to split each rollout into.
            Must be positive.
        num_sgd_steps: Number of SGD passes, handed to ``SGDUpdate``.
        num_envs: Number of parallel environments, handed to
            ``FullBufferBatch``.
        grad_clip: Global gradient norm clipping threshold.
        **kwargs: Accepted and ignored, so extra configuration keys do not
            raise.

    Returns:
        A configured ``Algorithm`` for PPO.

    Raises:
        ValueError: If ``num_minibatches`` is not positive or does not
            evenly divide ``rollout_length``.
    """
    # Explicit raises instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let an invalid configuration through.
    if num_minibatches <= 0:
        raise ValueError("Number of minibatches must be a positive integer")
    if rollout_length % num_minibatches != 0:
        raise ValueError(
            "Rollout length must be divisible by number of minibatches"
        )

    # inject_hyperparams exposes the learning rate as part of the optimizer
    # state instead of baking it into the update function.
    schedule = optax.constant_schedule(lr)
    adam = optax.inject_hyperparams(optax.adam)(learning_rate=schedule)
    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip), adam)

    return Algorithm(
        OnPolicyCollector(rollout_length, gamma, _lambda),
        ComposedBatchBuilder(
            (
                FullBufferBatch(rollout_length, num_envs),
                MiniBatch(num_minibatches),
                NormalizeAdvantagesTargets(normalize_targets=False),
            ),
        ),
        SGDUpdate(
            optimizer=optimizer,
            loss_fn=PPOLoss(
                clip_schedule=clip_schedule,
                vf_coef_schedule=vf_schedule,
                ent_coef_schedule=ent_schedule,
            ),
            num_sgd_steps=num_sgd_steps,
            # NOTE(review): the optimizer chain above already clips by
            # global norm with this threshold; confirm SGDUpdate does not
            # clip a second time.
            grad_clip=grad_clip,
        ),
    )

bordax.algorithms.losses

bordax.algorithms.utils

make_algo(algo_name, algo_config={})

Create an algorithm by name.

Parameters:

Name Type Description Default
algo_name str

Algorithm identifier. Supported values:

  • "ppo" — Proximal Policy Optimization (on-policy)
  • "dqn" — Deep Q-Network (off-policy)
required
algo_config dict

Dict of hyperparameters forwarded to the algorithm factory function. See ppo_algo and dqn_algo for the accepted keys.

{}

Returns:

Type Description
Algorithm

A configured Algorithm instance.

Raises:

Type Description
ValueError

If algo_name is not in the registry.

Source code in bordax/algorithms/utils.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def make_algo(algo_name: str, algo_config: dict = {}) -> Algorithm:
    """Create an algorithm by name.

    Args:
        algo_name: Algorithm identifier. Supported values:

            - ``"ppo"`` — Proximal Policy Optimization (on-policy)
            - ``"dqn"`` — Deep Q-Network (off-policy)

        algo_config: Dict of hyperparameters forwarded as keyword arguments
            to the algorithm factory function. See ``ppo_algo`` and
            ``dqn_algo`` for the accepted keys.

    Returns:
        A configured ``Algorithm`` instance.

    Raises:
        ValueError: If ``algo_name`` is not in the registry.
    """
    # NOTE: the shared ``{}`` default is safe here because ``algo_config``
    # is only unpacked below, never mutated or stored.
    try:
        alg = ALGO_REGISTRY[algo_name]
    except KeyError:
        # ``from None`` drops the internal KeyError from the traceback; the
        # message below already names the bad key and the valid choices.
        raise ValueError(
            f"Algo {algo_name} is not supported. Supported algos are: {list(ALGO_REGISTRY.keys())}"
        ) from None
    return alg(**algo_config)