Skip to content

bordax.data

bordax.data.collectors

Collector

Bases: ABC

Abstract base class for data collectors.

A collector interacts with the environment for a fixed number of steps and returns the resulting transitions, optionally storing them in a replay buffer.

Source code in bordax/data/collectors.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Collector(ABC):
    """Interface for objects that gather experience from an environment.

    A concrete collector steps ``env`` with ``agent`` for a fixed number of
    steps and hands back what it gathered: either a freshly built trajectory
    (on-policy) or a replay buffer the new transitions were appended to
    (off-policy).
    """

    @abstractmethod
    def __call__(
        self,
        key: PRNGKey,
        env: EnvAdapter,
        obs: EnvObs,
        env_state: EnvState,
        replay_buffer: Any,
        agent: Agent,
        ts: TrainingState,
    ) -> Tuple[Tuple[Any, EnvState], Any]:
        """Step the environment and return the gathered experience.

        Args:
            key: JAX random key consumed by the collector.
            env: Environment to step.
            obs: Observation batch the rollout starts from.
            env_state: Environment state batch the rollout starts from.
            replay_buffer: Buffer to append to (off-policy collectors);
                on-policy collectors ignore it.
            agent: Agent that chooses the actions.
            ts: Training state supplying the parameters.

        Returns:
            A pair ``((next_obs, next_env_state), buffer)``. ``buffer`` is a
            trajectory dict for on-policy collectors, or the updated
            ``ReplayBuffer`` for off-policy ones.
        """
        ...

__call__(key, env, obs, env_state, replay_buffer, agent, ts) abstractmethod

Collect transitions from the environment.

Parameters:

Name Type Description Default
key PRNGKey

JAX random key.

required
env EnvAdapter

The environment to interact with.

required
obs EnvObs

Current observation batch.

required
env_state EnvState

Current environment state batch.

required
replay_buffer Any

Existing replay buffer (on-policy: ignored; off-policy: transitions are appended to it).

required
agent Agent

Agent used to select actions.

required
ts TrainingState

Current training state (provides parameters).

required

Returns:

Type Description
Tuple[Tuple[Any, EnvState], Any]

Tuple of ((next_obs, next_env_state), buffer), where buffer is a trajectory dict (on-policy) or the updated ReplayBuffer (off-policy).

Source code in bordax/data/collectors.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@abstractmethod
def __call__(
    self,
    key: PRNGKey,
    env: EnvAdapter,
    obs: EnvObs,
    env_state: EnvState,
    replay_buffer: Any,
    agent: Agent,
    ts: TrainingState,
) -> Tuple[Tuple[Any, EnvState], Any]:
    """Step the environment and return the gathered experience.

    Args:
        key: JAX random key consumed by the collector.
        env: Environment to step.
        obs: Observation batch the rollout starts from.
        env_state: Environment state batch the rollout starts from.
        replay_buffer: Buffer to append to (off-policy collectors);
            on-policy collectors ignore it.
        agent: Agent that chooses the actions.
        ts: Training state supplying the parameters.

    Returns:
        A pair ``((next_obs, next_env_state), buffer)``. ``buffer`` is a
        trajectory dict for on-policy collectors, or the updated
        ``ReplayBuffer`` for off-policy ones.
    """
    ...

EpsGreedyCollector

Bases: Collector

Collects transitions using an epsilon-greedy policy (for DQN).

At each step, with probability epsilon a random action is taken; otherwise the greedy action from the Q-network is used. Collected transitions are added to the replay buffer one at a time.

Source code in bordax/data/collectors.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
class EpsGreedyCollector(Collector):
    """Collects transitions using an epsilon-greedy policy (for DQN).

    At each step, with probability ``epsilon`` a random action is taken;
    otherwise the action returned by ``agent.action`` (the greedy action from
    the Q-network) is used. One exploration draw is made per step and is
    shared by all parallel environments. Collected transitions are added to
    the replay buffer one time step at a time.
    """

    def __init__(self, epsilon_schedule: Callable[[int], float], rollout_length: int = 1):
        """
        Args:
            epsilon_schedule: Callable ``(step) -> epsilon`` that controls
                exploration over training.
            rollout_length: Number of environment steps collected per call.
                Typically 1 for standard DQN.
        """
        self.epsilon_schedule = epsilon_schedule
        self.rollout_length = rollout_length

    @functools.partial(jax.jit, static_argnames=("self", "agent", "env"))
    def _jit_collect(self, key: PRNGKey, env: EnvAdapter, obs: EnvObs,
                     env_state: EnvState, agent: Agent, params: Params, epsilon: float):
        """Collect ``rollout_length`` steps on-device inside ``jax.lax.scan``.

        Returns:
            ``((final_obs, final_env_state), transitions)``. Each entry of
            ``transitions`` (``obs``, ``action``, ``reward``, ``next_obs``,
            ``done``) is stacked along a leading axis of length
            ``rollout_length``.
        """
        # ``env`` is a static argument, so its action space is a trace-time
        # constant and can be looked up once instead of inside every step.
        action_space = env.action_space()
        is_discrete = hasattr(action_space, 'n')

        def one_step(carry, unused):
            key, obs, env_state = carry
            # ``rand_key`` is separate from ``act_key``. This keeps the
            # exploration action statistically independent of whatever
            # randomness ``agent.action`` consumes.
            key, explore_key, act_key, rand_key, env_key = jax.random.split(key, 5)
            do_explore = jax.random.uniform(explore_key) < epsilon
            greedy_action, _ = agent.action(params, obs, act_key)
            if is_discrete:
                random_action = jax.random.randint(
                    rand_key, greedy_action.shape, 0, action_space.n
                )
            else:
                random_action = jax.random.uniform(
                    rand_key,
                    greedy_action.shape,
                    minval=action_space.low,
                    maxval=action_space.high,
                )

            # ``do_explore`` is a scalar, so all environments explore (or
            # exploit) together on this step.
            action = jax.lax.select(do_explore, random_action, greedy_action)
            n_obs, n_env_state, reward, done, _ = env.step(env_key, env_state, action)

            transition = {
                'obs': obs,
                'action': action,
                'reward': reward,
                'next_obs': n_obs,
                'done': done,
            }
            return (key, n_obs, n_env_state), transition

        (_, final_obs, final_state), transitions = jax.lax.scan(
            one_step,
            (key, obs, env_state),
            None,
            length=self.rollout_length,
        )
        return (final_obs, final_state), transitions

    def _non_jittable_collect(self, key: PRNGKey, env: EnvAdapter, obs: EnvObs,
                              env_state: EnvState, agent: Agent, params: Params,
                              epsilon: float):
        """Collect transitions using epsilon-greedy policy for non-jittable environments.

        Steps the environment in a Python loop and writes into preallocated
        NumPy arrays. The result is converted to JAX arrays at the end.
        Exploration actions are drawn from ``key`` rather than NumPy's global
        RNG, so a collection is reproducible from its PRNG key, as in the
        jitted path.

        Returns:
            ``((final_obs, final_env_state), transitions)``. ``transitions``
            has the same keys as ``_jit_collect``, with leaves of shape
            ``(rollout_length, num_envs, ...)``.
        """
        action_space = env.action_space()
        obs_shape = env.obs_space().shape
        action_shape = action_space.shape
        is_discrete = hasattr(action_space, 'n')
        action_dtype = np.int32 if is_discrete else np.float32
        lead = (self.rollout_length, env.num_envs)

        buffer = {
            "obs": np.zeros(lead + obs_shape, dtype=np.float32),
            "action": np.zeros(lead + action_shape, dtype=action_dtype),
            "reward": np.zeros(lead, dtype=np.float32),
            "next_obs": np.zeros(lead + obs_shape, dtype=np.float32),
            "done": np.zeros(lead, dtype=np.bool_),
        }

        for i in range(self.rollout_length):
            key, explore_key, act_key, rand_key, env_key = jax.random.split(key, 5)

            buffer["obs"][i] = np.asarray(obs)

            do_explore = float(jax.random.uniform(explore_key)) < epsilon

            if do_explore:
                if is_discrete:
                    action = np.asarray(
                        jax.random.randint(rand_key, (env.num_envs,), 0, action_space.n)
                    )
                else:
                    action = np.asarray(
                        jax.random.uniform(
                            rand_key,
                            (env.num_envs,) + action_shape,
                            minval=action_space.low,
                            maxval=action_space.high,
                        ),
                        dtype=np.float32,
                    )
            else:
                action, _ = agent.action(params, obs, act_key)
                action = np.asarray(action)

            n_obs, n_env_state, reward, done, _ = env.step(env_key, env_state, action)

            buffer["action"][i] = action
            buffer["reward"][i] = np.asarray(reward)
            buffer["next_obs"][i] = np.asarray(n_obs)
            buffer["done"][i] = np.asarray(done)

            obs = n_obs
            env_state = n_env_state

        transitions = jax.tree_util.tree_map(jnp.asarray, buffer)
        return (obs, env_state), transitions

    def __call__(self, key: PRNGKey, env: EnvAdapter, obs: EnvObs, env_state: EnvState,
                 replay_buffer: Any, agent: Agent, ts: TrainingState) -> Tuple[Tuple[Any, EnvState], Any]:
        """Collect ``rollout_length`` steps and add them to ``replay_buffer``.

        ``epsilon`` is taken from ``epsilon_schedule(ts.step.item())``.

        Returns:
            ``((next_obs, next_env_state), replay_buffer)``. The buffer is the
            same object that was passed in, after ``add`` has been called on it
            once per collected time step.
        """
        epsilon = self.epsilon_schedule(ts.step.item())

        if env.is_jittable:
            (obs, env_state), transitions = self._jit_collect(
                key, env, obs, env_state, agent, ts.params, epsilon
            )
        else:
            (obs, env_state), transitions = self._non_jittable_collect(
                key, env, obs, env_state, agent, ts.params, epsilon
            )

        transitions_np = jax.tree_util.tree_map(np.asarray, transitions)

        # Add one time step per ``add`` call. The slice keeps a leading time
        # axis of length 1 on every leaf.
        for i in range(self.rollout_length):
            transition = jax.tree_util.tree_map(lambda x: x[i:i + 1], transitions_np)
            replay_buffer.add(transition)

        return (obs, env_state), replay_buffer

__init__(epsilon_schedule, rollout_length=1)

Parameters:

Name Type Description Default
epsilon_schedule Callable[[int], float]

Callable (step) -> epsilon that controls exploration over training.

required
rollout_length int

Number of environment steps collected per call. Typically 1 for standard DQN.

1
Source code in bordax/data/collectors.py
245
246
247
248
249
250
251
252
253
254
def __init__(self, epsilon_schedule: Callable[[int], float], rollout_length: int = 1):
    """Configure the exploration schedule and per-call rollout length.

    Args:
        epsilon_schedule: Maps the current training step to the exploration
            probability ``epsilon`` used for that call.
        rollout_length: How many environment steps each call gathers;
            standard DQN typically uses 1.
    """
    self.rollout_length = rollout_length
    self.epsilon_schedule = epsilon_schedule

OnPolicyCollector

Bases: Collector

Collects full rollouts for on-policy algorithms (e.g. PPO).

For jittable environments the rollout is gathered inside jax.lax.scan, keeping everything on-device. For non-jittable environments a Python loop is used instead, with a final device transfer. GAE advantages and value targets are computed after collection.

Source code in bordax/data/collectors.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class OnPolicyCollector(Collector):
    """Collects full rollouts for on-policy algorithms (e.g. PPO).

    For jittable environments the rollout is gathered inside
    ``jax.lax.scan``, keeping everything on-device. For non-jittable
    environments a Python loop is used instead, with a final device
    transfer. GAE advantages and value targets are computed after
    collection.
    """

    def __init__(
        self, rollout_length: int = 1024, gamma: float = 0.99, _lambda: float = 0.99
    ):
        """
        Args:
            rollout_length: Number of environment steps per rollout.
            gamma: Discount factor used in GAE computation.
            _lambda: GAE lambda parameter controlling the bias-variance tradeoff.
        """
        self.rollout_length = rollout_length
        self.gamma = gamma
        self._lambda = _lambda

    @functools.partial(jax.jit, static_argnames=("self", "agent", "env"))
    def collect_jittable(self, key, env, obs, env_state, agent: Agent, params):
        """Collect ``rollout_length`` steps on-device with ``jax.lax.scan``.

        Returns:
            ``((key, last_obs, last_env_state), traj)``. ``traj`` holds
            ``obs``, ``action``, ``reward``, ``done`` and the agent's action
            ``info``, each stacked along a leading time axis.
        """
        init_obs, init_state = obs, env_state

        def one_step(carry, unused):
            key, obs, env_state = carry
            key, act_key, env_key = jax.random.split(key, 3)
            action, info = agent.action(params, obs, act_key)

            n_obs, n_env_state, reward, done, env_info = env.step(
                env_key, env_state, action
            )

            transition = dict(
                obs=obs,
                action=action,
                reward=reward,
                done=done,
                info=info,
            )

            return (key, n_obs, n_env_state), transition

        (key, last_obs, last_env_state), traj = jax.lax.scan(
            one_step,
            (key, init_obs, init_state),
            None,
            length=self.rollout_length,
        )

        return (key, last_obs, last_env_state), traj

    def collect_non_jittable(
        self, key, env: EnvAdapter, obs, env_state, agent: Agent, params
    ):
        """Collect a rollout with a Python loop for environments that cannot be jitted.

        Returns:
            ``((last_obs, last_env_state), traj)``. ``traj`` has the same
            top-level keys as ``collect_jittable``, except that ``info`` holds
            only ``logp``. It also has a ``value`` entry with the critic's
            estimate for every collected observation.
        """
        # Non-jittable environments run on the host, so data is written into
        # NumPy arrays directly and converted to JAX arrays once at the end.
        # The per-step host<->device transfer for the agent calls may still be
        # a bottleneck.
        action_space = env.action_space()
        obs_shape = env.obs_space().shape
        action_shape = action_space.shape
        # Discrete spaces expose ``n`` and use integer actions. Anything else
        # is treated as continuous and stored as float32, so the actions are
        # not truncated when written into the buffer.
        action_dtype = np.int32 if hasattr(action_space, 'n') else np.float32
        lead = (self.rollout_length, env.num_envs)

        buffer = {
            "obs": np.zeros(lead + obs_shape, dtype=np.float32),
            "action": np.zeros(lead + action_shape, dtype=action_dtype),
            "value": np.zeros(lead, dtype=np.float32),
            "reward": np.zeros(lead, dtype=np.float32),
            "done": np.zeros(lead, dtype=np.bool_),
            "info": {
                "logp": np.zeros(lead, dtype=np.float32)
            },
        }

        for i in range(self.rollout_length):
            key, act_key, env_key = jax.random.split(key, 3)

            buffer["obs"][i] = np.asarray(obs)

            # Agent methods return JAX arrays.
            action, action_info = agent.action(params, obs, act_key)
            action = np.asarray(action)

            # The environment gets its own key (``env_key``) so that its
            # randomness is independent of the action sampling.
            n_obs, n_env_state, reward, done, env_info = env.step(
                env_key, env_state, action
            )

            buffer["action"][i] = action
            buffer["reward"][i] = np.asarray(reward)
            buffer["done"][i] = np.asarray(done)
            buffer["info"]["logp"][i] = np.asarray(action_info["logp"])

            obs = n_obs
            env_state = n_env_state

        # Evaluate the critic on all collected observations in one batched call.
        buffer["value"] = np.asarray(agent.value(params, jnp.asarray(buffer["obs"])))

        traj = jax.tree_util.tree_map(jnp.asarray, buffer)

        return (obs, env_state), traj

    def __call__(self, key, env, obs, env_state, replay_buffer: Any, agent: Agent, ts: TrainingState):
        """Collect a rollout and attach GAE advantages and value targets.

        ``replay_buffer`` is not used.

        Returns:
            ``((last_obs, last_env_state), traj)``. ``traj`` additionally
            contains ``advantages`` and ``targets`` from ``compute_gae``.
        """
        # Collect rollout (returns JAX arrays in both cases).
        if env.is_jittable:
            (_, last_obs, last_env_state), traj = self.collect_jittable(
                key, env, obs, env_state, agent, ts.params
            )
        else:
            (last_obs, last_env_state), traj = self.collect_non_jittable(
                key, env, obs, env_state, agent, ts.params
            )

        last_value = agent.value(ts.params, last_obs)
        if "value" in traj:
            # The non-jittable path already evaluated the critic on exactly
            # these observations with the same parameters; reuse the result.
            values = traj["value"]
        else:
            values = agent.value(ts.params, traj["obs"])

        advantages, targets = jax.lax.stop_gradient(
            compute_gae(traj, last_value, values, self.gamma, self._lambda)
        )

        traj["advantages"] = advantages
        traj["targets"] = targets

        return (last_obs, last_env_state), traj

__init__(rollout_length=1024, gamma=0.99, _lambda=0.99)

Parameters:

Name Type Description Default
rollout_length int

Number of environment steps per rollout.

1024
gamma float

Discount factor used in GAE computation.

0.99
_lambda float

GAE lambda parameter controlling the bias-variance tradeoff.

0.99
Source code in bordax/data/collectors.py
63
64
65
66
67
68
69
70
71
72
73
74
def __init__(
    self, rollout_length: int = 1024, gamma: float = 0.99, _lambda: float = 0.99
):
    """Record the rollout length and the GAE hyper-parameters.

    Args:
        rollout_length: How many environment steps each rollout contains.
        gamma: Discount factor applied when computing GAE.
        _lambda: GAE lambda, trading bias against variance in the advantage
            estimates.
    """
    self._lambda = _lambda
    self.gamma = gamma
    self.rollout_length = rollout_length

compute_gae(traj_batch, last_value, values, gamma, gae_lambda)

Compute Generalised Advantage Estimates (GAE) for a rollout.

Parameters:

Name Type Description Default
traj_batch

Dict of trajectory arrays with shape (rollout_length, num_envs, ...).

required
last_value

Value estimate for the observation after the last step, shape (num_envs,).

required
values

Value estimates for all observations in the rollout, shape (rollout_length, num_envs).

required
gamma

Discount factor.

required
gae_lambda

GAE lambda parameter.

required

Returns:

Type Description

Tuple of (advantages, targets), each with shape (rollout_length, num_envs).

Source code in bordax/data/collectors.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
@jax.jit
def compute_gae(traj_batch, last_value, values, gamma, gae_lambda):
    """Compute Generalised Advantage Estimates (GAE) for a rollout.

    Runs a reverse-time ``jax.lax.scan`` implementing::

        delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
        A_t     = delta_t + gamma * gae_lambda * (1 - done_t) * A_{t+1}

    with ``V_T = last_value`` and ``A_T = 0``.

    Args:
        traj_batch: Dict of trajectory arrays with shape
            ``(rollout_length, num_envs, ...)``. Only ``"done"`` and
            ``"reward"`` are read; other entries are ignored.
        last_value: Value estimate for the observation after the last step,
            shape ``(num_envs,)``.
        values: Value estimates for all observations in the rollout,
            shape ``(rollout_length, num_envs)``.
        gamma: Discount factor.
        gae_lambda: GAE lambda parameter.

    Returns:
        Tuple of ``(advantages, targets)`` each with shape
        ``(rollout_length, num_envs)``, where ``targets = advantages + values``.
    """

    def _backward_step(carry, step):
        gae, next_value = carry
        done, reward, value = step
        not_done = 1 - done

        delta = reward + gamma * next_value * not_done - value
        gae = delta + gamma * gae_lambda * not_done * gae

        # This step's value becomes the "next value" for the step before it.
        return (gae, value), gae

    # Only scan over the arrays the recursion needs, rather than threading the
    # whole trajectory (observations, agent info, ...) through the scan.
    _, advantages = jax.lax.scan(
        _backward_step,
        (jnp.zeros_like(last_value), last_value),
        (traj_batch["done"], traj_batch["reward"], values),
        reverse=True,
    )

    return advantages, advantages + values

bordax.data.batchbuilders

BatchBuilder

Bases: ABC

Abstract base class for batch builders.

A batch builder transforms a raw buffer (trajectory dict or replay buffer) into the format expected by the updater. Batch builders can be chained via ComposedBatchBuilder.

Source code in bordax/data/batchbuilders.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class BatchBuilder(ABC):
    """Abstract base class for batch builders.

    A batch builder transforms a raw buffer (trajectory dict or replay
    buffer) into the format expected by the updater. Batch builders can
    be chained via ``ComposedBatchBuilder``, which feeds each builder's
    output in as the next builder's ``buffer``.
    """

    @abstractmethod
    def __call__(
        self, key: PRNGKey, buffer: Any
    ) -> Mapping[str, Any]:
        """Transform a buffer into a training batch.

        Args:
            key: JAX random key (for shuffling or sampling).
            buffer: Raw data — a trajectory dict (on-policy) or a
                ``ReplayBuffer`` instance (off-policy).

        Returns:
            A (possibly nested) batch dict of JAX arrays ready for the
            updater. Only the batch is returned, not a ``(key, batch)`` pair.
        """
        ...

__call__(key, buffer) abstractmethod

Transform a buffer into a training batch.

Parameters:

Name Type Description Default
key PRNGKey

JAX random key (for shuffling or sampling).

required
buffer Any

Raw data — a trajectory dict (on-policy) or a ReplayBuffer instance (off-policy).

required

Returns:

Type Description
Tuple[PRNGKey, Mapping[str, ndarray]]

A batch dict of JAX arrays ready for the updater.

Source code in bordax/data/batchbuilders.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
@abstractmethod
def __call__(
    self, key: PRNGKey, buffer: Any
) -> Mapping[str, Any]:
    """Transform a buffer into a training batch.

    Args:
        key: JAX random key (for shuffling or sampling).
        buffer: Raw data — a trajectory dict (on-policy) or a
            ``ReplayBuffer`` instance (off-policy).

    Returns:
        A (possibly nested) batch dict of JAX arrays ready for the
        updater. Only the batch is returned, not a ``(key, batch)`` pair.
    """
    ...

ComposedBatchBuilder

Bases: BatchBuilder

Apply a sequence of batch builders in order.

Each builder's output is passed as input to the next. The full composed call is JIT-compiled. Typical PPO usage:

ComposedBatchBuilder((
    FullBufferBatch(rollout_length, num_envs),
    MiniBatch(num_minibatches),
    NormalizeAdvantagesTargets(),
))
Source code in bordax/data/batchbuilders.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class ComposedBatchBuilder(BatchBuilder):
    """Apply a sequence of batch builders in order.

    Each builder's output is passed as input to the next. The full
    composed call is JIT-compiled, so every stage must be traceable by
    ``jax.jit``. Typical PPO usage::

        ComposedBatchBuilder((
            FullBufferBatch(rollout_length, num_envs),
            MiniBatch(num_minibatches),
            NormalizeAdvantagesTargets(),
        ))
    """

    def __init__(self, batch_builders: Sequence[BatchBuilder]):
        """
        Args:
            batch_builders: Ordered sequence of batch builders to apply.
        """
        self.batch_builders = batch_builders

    @functools.partial(jax.jit, static_argnames=("self",))
    def __call__(
        self, key: PRNGKey, buffer: Any
    ) -> Mapping[str, Any]:
        """Run every stage in order, feeding each output to the next stage.

        Args:
            key: JAX random key; it is split so each stage receives its own
                sub-key.
            buffer: Input to the first stage.

        Returns:
            The output of the last stage, or ``buffer`` unchanged if there
            are no stages.
        """
        stage_keys = jax.random.split(key, len(self.batch_builders))
        for stage_key, batch_builder in zip(stage_keys, self.batch_builders):
            buffer = batch_builder(stage_key, buffer)

        return buffer

__init__(batch_builders)

Parameters:

Name Type Description Default
batch_builders Sequence[BatchBuilder]

Ordered sequence of batch builders to apply.

required
Source code in bordax/data/batchbuilders.py
148
149
150
151
152
153
def __init__(self, batch_builders: Sequence[BatchBuilder]):
    """Remember the stages to run.

    Args:
        batch_builders: Batch builders applied one after another, in the
            given order.
    """
    stages = batch_builders
    self.batch_builders = stages

FullBufferBatch

Bases: BatchBuilder

Flatten and shuffle an entire on-policy rollout into a single batch.

Merges the time and environment dimensions, then applies a random permutation. Typically the first stage in a ComposedBatchBuilder for PPO, followed by MiniBatch.

Source code in bordax/data/batchbuilders.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class FullBufferBatch(BatchBuilder):
    """Flatten and shuffle an entire on-policy rollout into a single batch.

    Merges the time and environment dimensions, then applies a random
    permutation. Typically the first stage in a ``ComposedBatchBuilder``
    for PPO, followed by ``MiniBatch``.
    """

    def __init__(self, buffer_size, num_env):
        """
        Args:
            buffer_size: Number of timesteps in the rollout.
            num_env: Number of parallel environments.
        """
        self.buffer_size = buffer_size
        self.num_env = num_env

    def __call__(self, key: PRNGKey, buffer: Any) -> Mapping[str, Any]:
        """Flatten ``(time, env, ...)`` leaves to ``(time * env, ...)`` and shuffle.

        Args:
            key: JAX random key used to draw the permutation.
            buffer: (Possibly nested) dict whose leaves have shape
                ``(buffer_size, num_env, ...)``.

        Returns:
            A dict with the same structure. Its leaves have shape
            ``(buffer_size * num_env, ...)`` and are all reordered by the same
            permutation, so rows stay aligned across entries.
        """
        _, perm_key = jax.random.split(key, 2)

        batch_size = self.buffer_size * self.num_env
        permutation = jax.random.permutation(perm_key, batch_size)

        def flatten_and_shuffle(x):
            # Merge the time and environment axes, then reorder the rows.
            flat = x.reshape((batch_size,) + x.shape[2:])
            return jnp.take(flat, permutation, axis=0)

        return jax.tree_util.tree_map(flatten_and_shuffle, buffer)

__init__(buffer_size, num_env)

Parameters:

Name Type Description Default
buffer_size

Number of timesteps in the rollout.

required
num_env

Number of parallel environments.

required
Source code in bordax/data/batchbuilders.py
41
42
43
44
45
46
47
48
def __init__(self, buffer_size, num_env):
    """Record the rollout geometry used to flatten the buffer.

    Args:
        buffer_size: Timesteps per environment in the rollout.
        num_env: How many environments ran in parallel.
    """
    self.num_env = num_env
    self.buffer_size = buffer_size

MiniBatch

Bases: BatchBuilder

Split a flat batch into equal-sized minibatches.

Reshapes the leading dimension into (num_minibatches, minibatch_size). The resulting array is iterated over by the updater's SGD loop.

Source code in bordax/data/batchbuilders.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class MiniBatch(BatchBuilder):
    """Split a flat batch into equal-sized minibatches.

    Reshapes the leading dimension into ``(num_minibatches, minibatch_size)``.
    The resulting array is iterated over by the updater's SGD loop.
    """

    def __init__(self, num_minibatches: int):
        """
        Args:
            num_minibatches: Number of minibatches to split the batch into.
                The batch size must be divisible by this value.
        """
        self.num_minibatches = num_minibatches

    def __call__(
        self, key: PRNGKey, buffer: Any
    ) -> Mapping[str, Any]:
        """Split every leaf along its leading axis into ``num_minibatches`` chunks.

        Args:
            key: Unused; accepted for interface compatibility.
            buffer: (Possibly nested) dict of arrays with a leading batch axis.

        Returns:
            The same structure with each leaf reshaped from ``(batch, ...)``
            to ``(num_minibatches, batch // num_minibatches, ...)``.
        """
        minibatches = jax.tree_util.tree_map(
            lambda x: x.reshape((self.num_minibatches, -1) + x.shape[1:]), buffer
        )

        return minibatches

__init__(num_minibatches)

Parameters:

Name Type Description Default
num_minibatches int

Number of minibatches to split the batch into. The batch size must be divisible by this value.

required
Source code in bordax/data/batchbuilders.py
77
78
79
80
81
82
83
def __init__(self, num_minibatches: int):
    """Set how many minibatches each flat batch is divided into.

    Args:
        num_minibatches: Count of equal-sized chunks; the flat batch size
            has to be a multiple of it.
    """
    chunks = num_minibatches
    self.num_minibatches = chunks

NormalizeAdvantagesTargets

Bases: BatchBuilder

Normalizes advantages (and optionally value targets) per minibatch.

Source code in bordax/data/batchbuilders.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class NormalizeAdvantagesTargets(BatchBuilder):
    """Normalizes advantages (and optionally value targets) per minibatch."""

    def __init__(self, eps: float = 1e-8, normalize_targets: bool = True):
        """
        Args:
            eps: Constant added to the standard deviation so the division
                stays finite when a minibatch has zero spread.
            normalize_targets: When ``True``, value targets are standardized
                as well as advantages.
        """
        self.eps = eps
        self.normalize_targets = normalize_targets

    def __call__(self, key: PRNGKey, buffer: Any) -> Any:
        """Standardize ``advantages`` (and optionally ``targets``) per minibatch.

        The function is ``jax.vmap``-ed over the leading axis of ``buffer``,
        which is presumably the minibatch axis produced by ``MiniBatch``.
        Statistics are therefore computed separately for each minibatch.
        ``key`` is unused.

        NOTE(review): standardized targets no longer share the scale of the
        raw value estimates they were built from; confirm the value loss
        expects this.
        """
        eps = self.eps

        def standardize(x):
            return (x - jnp.mean(x)) / (jnp.std(x) + eps)

        def per_minibatch(mb):
            targets = mb["targets"]
            if self.normalize_targets:
                targets = standardize(targets)
            return {
                **mb,
                "advantages": standardize(mb["advantages"]),
                "targets": targets,
            }

        return jax.vmap(per_minibatch)(buffer)

__init__(eps=1e-08, normalize_targets=True)

Parameters:

Name Type Description Default
eps float

Small constant added to the standard deviation for numerical stability.

1e-08
normalize_targets bool

If True, also normalise value targets in addition to advantages.

True
Source code in bordax/data/batchbuilders.py
 99
100
101
102
103
104
105
106
107
108
def __init__(self, eps: float = 1e-8, normalize_targets: bool = True):
    """Store the normalization settings.

    Args:
        eps: Constant added to the standard deviation so the division stays
            finite when a minibatch has zero spread.
        normalize_targets: When ``True``, value targets are standardized as
            well as advantages.
    """
    self.normalize_targets = normalize_targets
    self.eps = eps

UniformReplayBatch

Bases: BatchBuilder

Sample a batch uniformly from a ReplayBuffer.

Source code in bordax/data/batchbuilders.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
class UniformReplayBatch(BatchBuilder):
    """Sample a batch uniformly from a ReplayBuffer."""

    def __init__(self, batch_size: int):
        """
        Args:
            batch_size: How many transitions to request from
                ``buffer.sample`` on each call.
        """
        self.batch_size = batch_size

    def __call__(self, key: PRNGKey, buffer: Any) -> Mapping[str, jnp.ndarray]:
        """Ask the buffer for ``batch_size`` transitions and move them to JAX.

        Args:
            key: Not used; accepted so the signature matches other batch
                builders.
            buffer: ``ReplayBuffer`` instance exposing ``sample(batch_size)``.

        Returns:
            The sampled structure with every leaf converted via ``jnp.array``.
        """
        return jax.tree_util.tree_map(jnp.array, buffer.sample(self.batch_size))

__call__(key, buffer)

Sample transitions from replay buffer and convert to JAX arrays.

Parameters:

Name Type Description Default
key PRNGKey

PRNG key (unused, but kept for interface consistency)

required
buffer Any

ReplayBuffer instance

required

Returns:

Type Description
Mapping[str, ndarray]

Dictionary of JAX arrays with keys: obs, action, reward, next_obs, done

Source code in bordax/data/batchbuilders.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def __call__(self, key: PRNGKey, buffer: Any) -> Mapping[str, jnp.ndarray]:
    """Draw ``batch_size`` transitions from ``buffer`` as JAX arrays.

    Args:
        key: PRNG key. It is not consumed here: the randomness comes from
            ``buffer.sample``. It is accepted so the call signature stays
            consistent with the rest of the batch-builder interface.
        buffer: ReplayBuffer instance (anything exposing
            ``sample(batch_size)``).

    Returns:
        The mapping returned by ``buffer.sample`` with every leaf turned
        into a JAX array (keys: obs, action, reward, next_obs, done).
    """
    sampled = buffer.sample(self.batch_size)
    # Convert each NumPy leaf to a JAX array; the dict structure is kept.
    return jax.tree_util.tree_map(jnp.array, sampled)

__init__(batch_size)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Number of transitions to sample per update. | *required* |
Source code in bordax/data/batchbuilders.py
168
169
170
171
172
173
def __init__(self, batch_size: int):
    """Remember how large each sampled minibatch should be.

    Args:
        batch_size: How many transitions each call draws from the buffer.
    """
    self.batch_size = batch_size

bordax.data.buffer

ReplayBuffer

A simple ring buffer for storing and sampling transitions for off-policy RL. This implementation is based on NumPy and is not designed to be JAX-jittable.

Source code in bordax/data/buffer.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class ReplayBuffer:
    """
    A simple ring buffer for storing and sampling transitions for off-policy RL.
    This implementation is based on NumPy and is not designed to be JAX-jittable.
    """

    def __init__(self, capacity: int, obs_shape: Tuple[int, ...], action_shape: Tuple[int, ...]):
        """
        Initializes the replay buffer.

        Args:
            capacity: The maximum number of transitions to store.
            obs_shape: The shape of a single observation.
            action_shape: The shape of a single action.
        """
        self.capacity = capacity
        self.obs_shape = obs_shape
        self.action_shape = action_shape

        self.observations = np.zeros((capacity, *obs_shape), dtype=np.float32)
        self.actions = np.zeros((capacity, *action_shape), dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_observations = np.zeros((capacity, *obs_shape), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)

        self._ptr = 0
        self._size = 0

    def add(self, rollout: Dict[str, np.ndarray]):
        """
        Adds a batch of transitions to the buffer.
        The input arrays in the rollout dictionary are expected to have the same leading dimension.
        Required keys: 'obs', 'action', 'reward', 'next_obs', 'done'.
        """
        num_transitions = rollout['obs'].shape[0]
        indices = np.arange(self._ptr, self._ptr + num_transitions) % self.capacity

        self.observations[indices] = rollout['obs']
        self.actions[indices] = rollout['action']
        self.rewards[indices] = rollout['reward']
        self.next_observations[indices] = rollout['next_obs']
        self.dones[indices] = rollout['done']

        self._ptr = (self._ptr + num_transitions) % self.capacity
        self._size = min(self._size + num_transitions, self.capacity)

    def sample(self, batch_size: int) -> Dict[str, np.ndarray]:
        """
        Samples a batch of transitions from the buffer.

        Args:
            batch_size: The number of transitions to sample.

        Returns:
            A dictionary containing the sampled transitions.
        """
        if self._size < batch_size:
            raise ValueError(f"Not enough samples in the buffer to sample {batch_size} transitions. "
                             f"Current size: {self._size}")

        indices = np.random.randint(0, self._size, size=batch_size)
        return {
            'obs': self.observations[indices],
            'action': self.actions[indices],
            'reward': self.rewards[indices],
            'next_obs': self.next_observations[indices],
            'done': self.dones[indices],
        }

    def __len__(self) -> int:
        return self._size

__init__(capacity, obs_shape, action_shape)

Initializes the replay buffer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `capacity` | `int` | The maximum number of transitions to store. | *required* |
| `obs_shape` | `Tuple[int, ...]` | The shape of a single observation. | *required* |
| `action_shape` | `Tuple[int, ...]` | The shape of a single action. | *required* |
Source code in bordax/data/buffer.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(
    self,
    capacity: int,
    obs_shape: Tuple[int, ...],
    action_shape: Tuple[int, ...],
    action_dtype=np.int32,
):
    """
    Initializes the replay buffer.

    Args:
        capacity: The maximum number of transitions to store. Must be positive.
        obs_shape: The shape of a single observation.
        action_shape: The shape of a single action.
        action_dtype: NumPy dtype used to store actions. Defaults to
            ``np.int32`` (discrete actions). Pass e.g. ``np.float32`` for
            continuous actions, which an integer array would truncate.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """
    if capacity <= 0:
        # A zero-size ring cannot be written to (the modulo in `add` would divide by zero).
        raise ValueError(f"capacity must be positive, got {capacity}")
    self.capacity = capacity
    self.obs_shape = obs_shape
    self.action_shape = action_shape

    self.observations = np.zeros((capacity, *obs_shape), dtype=np.float32)
    self.actions = np.zeros((capacity, *action_shape), dtype=action_dtype)
    self.rewards = np.zeros(capacity, dtype=np.float32)
    self.next_observations = np.zeros((capacity, *obs_shape), dtype=np.float32)
    self.dones = np.zeros(capacity, dtype=np.bool_)

    self._ptr = 0   # next write position
    self._size = 0  # number of valid entries, capped at capacity

add(rollout)

Adds a batch of transitions to the buffer. The input arrays in the rollout dictionary are expected to have the same leading dimension. Required keys: 'obs', 'action', 'reward', 'next_obs', 'done'.

Source code in bordax/data/buffer.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def add(self, rollout: Dict[str, np.ndarray]):
    """
    Adds a batch of transitions to the buffer.

    The input arrays in the rollout dictionary are expected to have the
    same leading dimension. Required keys: 'obs', 'action', 'reward',
    'next_obs', 'done'. Transitions are written in order starting at the
    write pointer and overwrite the oldest entries once the buffer is full.
    If the batch holds more than ``capacity`` transitions, only its last
    ``capacity`` ones are kept, which matches inserting them one by one.
    """
    num_transitions = rollout['obs'].shape[0]
    # Writing more than `capacity` rows in one fancy-indexed assignment would
    # repeat indices, and NumPy does not guarantee which value wins. Only the
    # trailing `capacity` rows survive sequential insertion, so write just those.
    start = max(0, num_transitions - self.capacity)
    indices = (self._ptr + np.arange(start, num_transitions)) % self.capacity

    def tail(values):
        # Pass the input through untouched in the common (non-overflow) case.
        return np.asarray(values)[start:] if start else values

    self.observations[indices] = tail(rollout['obs'])
    self.actions[indices] = tail(rollout['action'])
    self.rewards[indices] = tail(rollout['reward'])
    self.next_observations[indices] = tail(rollout['next_obs'])
    self.dones[indices] = tail(rollout['done'])

    self._ptr = (self._ptr + num_transitions) % self.capacity
    self._size = min(self._size + num_transitions, self.capacity)

sample(batch_size)

Samples a batch of transitions from the buffer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | The number of transitions to sample. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, ndarray]` | A dictionary containing the sampled transitions. |

Source code in bordax/data/buffer.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def sample(self, batch_size: int) -> Dict[str, np.ndarray]:
    """
    Samples a batch of transitions from the buffer.

    Args:
        batch_size: The number of transitions to sample.

    Returns:
        A dictionary containing the sampled transitions.
    """
    if self._size < batch_size:
        raise ValueError(f"Not enough samples in the buffer to sample {batch_size} transitions. "
                         f"Current size: {self._size}")

    indices = np.random.randint(0, self._size, size=batch_size)
    return {
        'obs': self.observations[indices],
        'action': self.actions[indices],
        'reward': self.rewards[indices],
        'next_obs': self.next_observations[indices],
        'done': self.dones[indices],
    }