Skip to content

bordax.training

bordax.training.trainer

Trainer

Orchestrates the full training loop.

Combines an environment, agent, and algorithm into a training loop that runs for a fixed number of checkpoints. Handles JIT compilation strategy, optional evaluation, logging, and checkpointing.

For on-policy algorithms with a jittable (e.g. Gymnax) environment, the entire train_step (collect + update) is JIT-compiled. For off-policy algorithms or non-jittable (e.g. Gymnasium) environments, train_step runs eagerly and only the update step is compiled.

Source code in bordax/training/trainer.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class Trainer:
    """Orchestrates the full training loop.

    Combines an environment, agent, and algorithm into a training loop
    that runs for a fixed number of checkpoints. Handles JIT compilation
    strategy, optional evaluation, logging, and checkpointing.

    For on-policy algorithms (no replay buffer configured) with a jittable
    environment such as Gymnax, the entire ``train_step`` (collect + update)
    is JIT-compiled. For off-policy algorithms or non-jittable environments
    such as Gymnasium, ``algo.train_step`` is called eagerly and any
    compilation of the update step is left to the algorithm.

    ``init`` must be called before ``run``; ``run`` reads the training
    state, observations and replay buffer that ``init`` creates.
    """

    def __init__(
        self,
        env: EnvAdapter,
        eval_env: EnvAdapter,
        agent: Agent,
        algo: Algorithm,
        config: TrainerConfig,
    ):
        """
        Args:
            env: Vectorised training environment (``num_envs`` parallel instances).
            eval_env: Single-instance evaluation environment (``num_envs=1``).
            agent: Agent defining the policy and value networks.
            algo: Algorithm providing collector, batch builder, and updater.
            config: Training configuration.
        """
        self.env = env
        self.eval_env = eval_env
        self.agent = agent
        self.algo = algo
        self.config = config
        self.replay_buffer = None  # Created by ``init`` for off-policy algorithms
        self.evaluator = Evaluator(eval_env, agent, config)
        if config.logger_config:
            self.logs_enabled = True
            self.logger_config = config.logger_config
            self.logger = Logger(self.logger_config)
        else:
            self.logs_enabled = False
        if config.checkpointer_config:
            self.checkpoints_enabled = True
            self.checkpointer_config = config.checkpointer_config
            self.checkpointer = Checkpointer(self.checkpointer_config)
        else:
            self.checkpoints_enabled = False

    def init(self, key: PRNGKey):
        """Initialise the training state.

        Resets the environment, initialises network parameters via
        ``algo.init_training_state``, and optionally restores a checkpoint
        or fills the replay buffer (off-policy).

        Args:
            key: JAX random key.

        Raises:
            ValueError: If ``eval_env`` does not have ``num_envs == 1``, or
                if ``config.restore_checkpoint`` is set while checkpointing
                is disabled (no ``checkpointer_config``).
        """
        # Validate up front, before paying for env reset / network init.
        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if self.eval_env.num_envs != 1:
            raise ValueError(
                f"eval_env must have num_envs=1, got {self.eval_env.num_envs}"
            )
        if self.config.restore_checkpoint and not self.checkpoints_enabled:
            # Otherwise the failure would be an opaque AttributeError on
            # ``self.checkpointer``.
            raise ValueError(
                "restore_checkpoint is set but checkpointing is disabled "
                "(config.checkpointer_config is not set); cannot restore."
            )

        key, env_key, init_key = jax.random.split(key, 3)
        self.last_obs, self.last_env_state = self.env.reset(env_key)
        self.training_state = self.algo.init_training_state(
            self.agent, init_key, self.last_obs, self.env
        )

        # Falsy values (None / 0) mean "start from scratch"; ``run`` saves
        # checkpoints starting at index 1.
        if self.config.restore_checkpoint:
            self.training_state = self.checkpointer.load(
                self.training_state, self.config.restore_checkpoint
            )

        # Initialize replay buffer for off-policy algorithms
        if self.config.replay_buffer_capacity is not None:
            from bordax.data.buffer import ReplayBuffer
            obs_shape = self.env.obs_space().shape
            action_shape = self.env.action_space().shape
            self.replay_buffer = ReplayBuffer(
                capacity=self.config.replay_buffer_capacity,
                obs_shape=obs_shape,
                action_shape=action_shape,
            )

            # Warmup: pre-fill the buffer by calling ``algo.collect``
            # ``warmup_steps`` times with the freshly initialised state.
            if self.config.warmup_steps is not None and self.config.warmup_steps > 0:
                print(f"Warming up replay buffer with {self.config.warmup_steps} transitions...")
                for i in range(self.config.warmup_steps):
                    key, collect_key = jax.random.split(key)
                    (self.last_obs, self.last_env_state), self.replay_buffer = self.algo.collect(
                        collect_key, self.env, self.last_obs, self.last_env_state,
                        self.replay_buffer, self.agent, self.training_state
                    )
                    if (i + 1) % 200 == 0 and self.config.debug:
                        print(f"  Warmup: {i+1}/{self.config.warmup_steps}, Buffer size: {len(self.replay_buffer)}")
                print(f"Buffer filled with {len(self.replay_buffer)} transitions\n")

    def _run_epoch(
        self,
        key: PRNGKey,
        train_step_fn: Optional[Callable],
    ) -> Tuple[PRNGKey, Any]:
        """Run a single training epoch and update the trainer state in place.

        Args:
            key: Random key consumed by the train step.
            train_step_fn: JIT-compiled train step with ``env`` and ``agent``
                already bound, or ``None`` to call ``algo.train_step`` eagerly.

        Returns:
            ``(next_key, metrics)`` as produced by the train step.
        """
        if train_step_fn is not None:
            # JIT-compiled path (on-policy, jittable env): no replay buffer,
            # so ``None`` is passed in and the returned slot is discarded.
            (
                key,
                self.training_state,
                _,
                self.last_obs,
                self.last_env_state,
            ), metrics = train_step_fn(
                key,
                self.training_state,
                None,
                self.last_obs,
                self.last_env_state,
            )
        else:
            # Eager path (off-policy or non-jittable env).
            (
                key,
                self.training_state,
                self.replay_buffer,
                self.last_obs,
                self.last_env_state,
            ), metrics = self.algo.train_step(
                self.env,
                self.agent,
                key,
                self.training_state,
                self.replay_buffer,
                self.last_obs,
                self.last_env_state,
            )
        return key, metrics

    def _run_checkpoint(self, training_key: PRNGKey, evaluate_key: PRNGKey, ckpt: int, train_step_fn, epoch_rollouts):
        """Run ``epochs_per_checkpoint`` epochs, then optionally evaluate.

        Args:
            training_key: Key threaded through the training epochs.
            evaluate_key: Key for the evaluation rollout.
            ckpt: Global checkpoint index (already offset by
                ``config.restore_checkpoint`` in ``run``); used as the
                evaluation logging step.
            train_step_fn: JIT-compiled train step or ``None`` (see
                ``_run_epoch``).
            epoch_rollouts: List mutated in place. One evaluation-result dict
                is appended, or an empty dict when evaluation is disabled.
        """
        # Training epochs; both the JIT and the eager path go through _run_epoch.
        for epoch in range(self.config.epochs_per_checkpoint):
            training_key, metrics = self._run_epoch(training_key, train_step_fn)

            if self.logs_enabled:
                # Global epoch index so training curves continue across checkpoints.
                self.logger.log_metrics(
                    {f"train/{k}": float(v) for k, v in metrics.items()},
                    step=ckpt * self.config.epochs_per_checkpoint + epoch,
                )

        if not self.config.enable_evaluation:
            epoch_rollouts.append({})
            return

        eval_result = self.evaluator.evaluate(evaluate_key, self.training_state.params)
        # Convert every leaf to a host NumPy array once, before aggregating.
        eval_result = jax.tree_util.tree_map(np.asarray, eval_result)

        avg_return = float(np.mean(eval_result["return"]))
        avg_length = float(np.mean(eval_result["length"]))

        # Optional extra episode info. Presumably a sequence of dicts sharing
        # the same keys (the first entry defines the key set) — TODO confirm
        # against the Evaluator. An empty sequence yields no extra metrics
        # instead of an IndexError.
        done_info = eval_result.get("done_info", None)
        if done_info is not None and len(done_info) > 0:
            avg_done_info = {
                k: float(np.mean([info[k] for info in done_info]))
                for k in done_info[0]
            }
        else:
            avg_done_info = {}

        if self.logs_enabled:
            entry = {
                "eval/avg_return": avg_return,
                "eval/avg_length": avg_length,
            }
            entry.update({f"eval/done_info/{k}": v for k, v in avg_done_info.items()})
            self.logger.log_evaluation(entry, step=ckpt)

        # Raw (NumPy-converted) eval results are returned to the caller of ``run``.
        epoch_rollouts.append(eval_result)

    def run(self, key: PRNGKey):
        """Run the full training loop.

        Iterates for ``config.num_checkpoints`` checkpoints, each running
        ``config.epochs_per_checkpoint`` training epochs. Evaluates the
        policy after each checkpoint if ``config.enable_evaluation=True``.

        Args:
            key: JAX random key.

        Returns:
            List of evaluation result dicts (one per checkpoint). Each dict
            contains ``"return"`` and ``"length"`` arrays over episodes.
            Empty dicts are appended when evaluation is disabled.
        """
        # Checkpoint numbering resumes after a restore (None / 0 -> start at 0).
        restore_offset = self.config.restore_checkpoint or 0
        pbar = None
        if self.config.debug:
            pbar = tqdm(
                initial=restore_offset,
                total=self.config.num_checkpoints + restore_offset,
            )

        try:
            # Nominal timestep budget: checkpoints x epochs x collector rollout
            # length. Collectors without ``rollout_length`` count as 1 per epoch.
            # NOTE(review): warmup steps are not included and the count is not
            # scaled by ``env.num_envs`` — confirm which unit is intended.
            rollout_len = getattr(self.algo.collector, 'rollout_length', 1)

            print(
                "Total number of timesteps: ",
                self.config.num_checkpoints
                * self.config.epochs_per_checkpoint
                * rollout_len,
            )

            _, training_key, evaluate_key = jax.random.split(key, 3)

            # On-policy + jittable env: compile the whole train step once, with
            # env and agent bound via functools.partial. Otherwise ``None``
            # selects the eager path, where train_step handles the
            # non-jittable replay buffer itself.
            train_step = None
            if self.env.is_jittable and self.replay_buffer is None:
                train_step = jax.jit(
                    functools.partial(self.algo.train_step, self.env, self.agent)
                )

            epoch_rollouts = []

            for ckpt in range(self.config.num_checkpoints):
                training_key, ckpt_training_key = jax.random.split(training_key)
                evaluate_key, ckpt_evaluate_key = jax.random.split(evaluate_key)
                current_ckpt = ckpt + restore_offset
                self._run_checkpoint(ckpt_training_key, ckpt_evaluate_key, current_ckpt, train_step, epoch_rollouts)
                if self.checkpoints_enabled:
                    # Saved under the number of checkpoints completed so far.
                    self.checkpointer.save(self.training_state, current_ckpt + 1)

                if pbar is not None:
                    pbar.update(1)

            return epoch_rollouts
        finally:
            # Release the progress bar even if training raises.
            if pbar is not None:
                pbar.close()

__init__(env, eval_env, agent, algo, config)

Parameters:

Name Type Description Default
env EnvAdapter

Vectorised training environment (num_envs parallel instances).

required
eval_env EnvAdapter

Single-instance evaluation environment (num_envs=1).

required
agent Agent

Agent defining the policy and value networks.

required
algo Algorithm

Algorithm providing collector, batch builder, and updater.

required
config TrainerConfig

Training configuration.

required
Source code in bordax/training/trainer.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def __init__(
    self,
    env: EnvAdapter,
    eval_env: EnvAdapter,
    agent: Agent,
    algo: Algorithm,
    config: TrainerConfig,
):
    """Wire together the components used by the training loop.

    Args:
        env: Vectorised training environment (``num_envs`` parallel instances).
        eval_env: Single-instance evaluation environment (``num_envs=1``).
        agent: Agent defining the policy and value networks.
        algo: Algorithm providing collector, batch builder, and updater.
        config: Training configuration. A truthy ``config.logger_config``
            enables logging; a truthy ``config.checkpointer_config`` enables
            checkpointing.
    """
    # Core components, stored as given.
    self.env, self.eval_env = env, eval_env
    self.agent, self.algo, self.config = agent, algo, config
    # Stays None unless ``init`` builds a buffer (off-policy algorithms).
    self.replay_buffer = None
    self.evaluator = Evaluator(eval_env, agent, config)

    # Optional logging: ``logger_config`` / ``logger`` exist only when enabled.
    self.logs_enabled = bool(config.logger_config)
    if self.logs_enabled:
        self.logger_config = config.logger_config
        self.logger = Logger(self.logger_config)

    # Optional checkpointing, following the same pattern as logging.
    self.checkpoints_enabled = bool(config.checkpointer_config)
    if self.checkpoints_enabled:
        self.checkpointer_config = config.checkpointer_config
        self.checkpointer = Checkpointer(self.checkpointer_config)

init(key)

Initialise the training state.

Resets the environment, initialises network parameters via algo.init_training_state, and optionally restores a checkpoint or fills the replay buffer (off-policy).

Parameters:

Name Type Description Default
key PRNGKey

JAX random key.

required
Source code in bordax/training/trainer.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def init(self, key: PRNGKey):
    """Initialise the training state.

    Resets the environment, initialises network parameters via
    ``algo.init_training_state``, and optionally restores a checkpoint
    or fills the replay buffer (off-policy).

    Args:
        key: JAX random key.

    Raises:
        ValueError: If ``eval_env`` does not have ``num_envs == 1``, or if
            ``config.restore_checkpoint`` is set while checkpointing is
            disabled (no ``checkpointer_config``).
    """
    # Validate up front, before paying for env reset / network init.
    # Explicit raises (not ``assert``) so the checks survive ``python -O``.
    if self.eval_env.num_envs != 1:
        raise ValueError(
            f"eval_env must have num_envs=1, got {self.eval_env.num_envs}"
        )
    if self.config.restore_checkpoint and not self.checkpoints_enabled:
        # Otherwise the failure would be an opaque AttributeError on
        # ``self.checkpointer``.
        raise ValueError(
            "restore_checkpoint is set but checkpointing is disabled "
            "(config.checkpointer_config is not set); cannot restore."
        )

    key, env_key, init_key = jax.random.split(key, 3)
    self.last_obs, self.last_env_state = self.env.reset(env_key)
    self.training_state = self.algo.init_training_state(
        self.agent, init_key, self.last_obs, self.env
    )

    # Falsy values (None / 0) mean "start from scratch"; ``run`` saves
    # checkpoints starting at index 1.
    if self.config.restore_checkpoint:
        self.training_state = self.checkpointer.load(
            self.training_state, self.config.restore_checkpoint
        )

    # Initialize replay buffer for off-policy algorithms
    if self.config.replay_buffer_capacity is not None:
        from bordax.data.buffer import ReplayBuffer
        obs_shape = self.env.obs_space().shape
        action_shape = self.env.action_space().shape
        self.replay_buffer = ReplayBuffer(
            capacity=self.config.replay_buffer_capacity,
            obs_shape=obs_shape,
            action_shape=action_shape,
        )

        # Warmup: pre-fill the buffer by calling ``algo.collect``
        # ``warmup_steps`` times with the freshly initialised state.
        if self.config.warmup_steps is not None and self.config.warmup_steps > 0:
            print(f"Warming up replay buffer with {self.config.warmup_steps} transitions...")
            for i in range(self.config.warmup_steps):
                key, collect_key = jax.random.split(key)
                (self.last_obs, self.last_env_state), self.replay_buffer = self.algo.collect(
                    collect_key, self.env, self.last_obs, self.last_env_state,
                    self.replay_buffer, self.agent, self.training_state
                )
                if (i + 1) % 200 == 0 and self.config.debug:
                    print(f"  Warmup: {i+1}/{self.config.warmup_steps}, Buffer size: {len(self.replay_buffer)}")
            print(f"Buffer filled with {len(self.replay_buffer)} transitions\n")

run(key)

Run the full training loop.

Iterates for config.num_checkpoints checkpoints, each running config.epochs_per_checkpoint training epochs. Evaluates the policy after each checkpoint if config.enable_evaluation=True.

Parameters:

Name Type Description Default
key PRNGKey

JAX random key.

required

Returns:

Type Description

List of evaluation result dicts (one per checkpoint). Each dict contains "return" and "length" arrays over episodes. Empty dicts are appended when evaluation is disabled.

Source code in bordax/training/trainer.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def run(self, key: PRNGKey):
    """Run the full training loop.

    Iterates for ``config.num_checkpoints`` checkpoints, each running
    ``config.epochs_per_checkpoint`` training epochs. Evaluates the
    policy after each checkpoint if ``config.enable_evaluation=True``.

    Args:
        key: JAX random key.

    Returns:
        List of evaluation result dicts (one per checkpoint). Each dict
        contains ``"return"`` and ``"length"`` arrays over episodes.
        Empty dicts are appended when evaluation is disabled.
    """
    # Checkpoint numbering resumes after a restore (None / 0 -> start at 0).
    restore_offset = self.config.restore_checkpoint or 0
    pbar = None
    if self.config.debug:
        pbar = tqdm(
            initial=restore_offset,
            total=self.config.num_checkpoints + restore_offset,
        )

    try:
        # Nominal timestep budget: checkpoints x epochs x collector rollout
        # length. Collectors without ``rollout_length`` count as 1 per epoch.
        # NOTE(review): warmup steps are not included and the count is not
        # scaled by ``env.num_envs`` — confirm which unit is intended.
        rollout_len = getattr(self.algo.collector, 'rollout_length', 1)

        print(
            "Total number of timesteps: ",
            self.config.num_checkpoints
            * self.config.epochs_per_checkpoint
            * rollout_len,
        )

        _, training_key, evaluate_key = jax.random.split(key, 3)

        # On-policy + jittable env: compile the whole train step once, with
        # env and agent bound via functools.partial. Otherwise ``None``
        # selects the eager path, where train_step handles the non-jittable
        # replay buffer itself.
        train_step = None
        if self.env.is_jittable and self.replay_buffer is None:
            train_step = jax.jit(
                functools.partial(self.algo.train_step, self.env, self.agent)
            )

        epoch_rollouts = []

        for ckpt in range(self.config.num_checkpoints):
            training_key, ckpt_training_key = jax.random.split(training_key)
            evaluate_key, ckpt_evaluate_key = jax.random.split(evaluate_key)
            current_ckpt = ckpt + restore_offset
            self._run_checkpoint(ckpt_training_key, ckpt_evaluate_key, current_ckpt, train_step, epoch_rollouts)
            if self.checkpoints_enabled:
                # Saved under the number of checkpoints completed so far.
                self.checkpointer.save(self.training_state, current_ckpt + 1)

            if pbar is not None:
                pbar.update(1)

        return epoch_rollouts
    finally:
        # Release the progress bar even if training raises.
        if pbar is not None:
            pbar.close()

TrainerConfig dataclass

Configuration for the Trainer.

Attributes:

Name Type Description
num_checkpoints int

Total number of training checkpoints (outer loop iterations). Each checkpoint runs epochs_per_checkpoint epochs and optionally evaluates the policy.

epochs_per_checkpoint int

Number of training epochs per checkpoint.

evaluation_episodes int

Number of episodes to average over during evaluation. Ignored if enable_evaluation=False.

logger_config Optional[LoggerConfig]

Optional LoggerConfig for WandB logging. If None, logging is disabled.

checkpointer_config Optional[Any]

Optional config for Orbax checkpointing. If None, checkpointing is disabled.

restore_checkpoint Optional[int]

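If set, restores the training state from the given checkpoint index before training starts; checkpoint numbering (logging steps, saved indices, progress bar) then resumes from that index.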

debug bool

If True, shows a tqdm progress bar during training and prints replay-buffer warmup progress.

replay_buffer_capacity Optional[int]

Capacity of the replay buffer for off-policy algorithms. If None, on-policy mode is assumed.

warmup_steps Optional[int]

Number of environment steps to collect into the replay buffer before training begins (off-policy only).

enable_evaluation bool

If False, skips policy evaluation at each checkpoint (useful for warmup or ablation runs).

Source code in bordax/training/trainer.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@dataclass
class TrainerConfig:
    """Settings consumed by ``Trainer``.

    The three required fields come first. Field order is part of the public
    interface because the dataclass can be constructed positionally.

    Attributes:
        num_checkpoints: Number of outer-loop iterations. Each iteration
            runs ``epochs_per_checkpoint`` training epochs and then
            optionally evaluates the policy.
        epochs_per_checkpoint: Training epochs executed per checkpoint.
        evaluation_episodes: Number of episodes averaged over during
            evaluation. Has no effect when ``enable_evaluation`` is ``False``.
        logger_config: ``LoggerConfig`` for WandB logging. ``None``
            disables logging.
        checkpointer_config: Config for Orbax checkpointing. ``None``
            disables checkpointing.
        restore_checkpoint: Checkpoint index to restore the training state
            from before training starts. ``None`` starts from scratch.
        debug: When ``True``, shows a tqdm progress bar during training
            and prints replay-buffer warmup progress.
        replay_buffer_capacity: Replay-buffer capacity for off-policy
            algorithms. ``None`` selects on-policy mode.
        warmup_steps: Number of collection steps used to pre-fill the
            replay buffer before training (off-policy only).
        enable_evaluation: When ``False``, per-checkpoint policy evaluation
            is skipped (useful for warmup or ablation runs).
    """

    # -- required: loop shape and evaluation size --
    num_checkpoints: int
    epochs_per_checkpoint: int
    evaluation_episodes: int

    # -- optional integrations (None disables them) --
    logger_config: Optional[LoggerConfig] = None
    checkpointer_config: Optional[Any] = None
    restore_checkpoint: Optional[int] = None

    # -- runtime behaviour --
    debug: bool = False
    replay_buffer_capacity: Optional[int] = None
    warmup_steps: Optional[int] = None
    enable_evaluation: bool = True

bordax.training.updaters

bordax.training.evaluation