
bordax.agents

bordax.agents.base

Agent base classes and simple policy/value MLP implementations.

This module defines the Agent abstract base class, the concrete agents used by the project, and the neural modules they are built from:

  • Agent: abstract interface for agents (init, policy, action, value).
  • BlankAgent: a uniform (random) agent for discrete action spaces.
  • DQNAgent: a DQN agent holding a Q-network and a target network.
  • MLP / MLP_dtsemnet / MLP_boolean: small neural modules used as policy (and, for MLP, value) architectures, documented under bordax.agents.components.
  • MLPPolicyValue / MLPPolicyValueContinuous: actor-critic wrappers that expose a policy (Categorical or Normal) and a value function.

Docstrings are provided for classes and public methods to aid reading and automatic documentation generation.

Agent

Bases: ABC

Abstract base class for all agents.

Subclasses must implement init and policy. The action method is provided as a JIT-compiled convenience wrapper around policy. Override value if the agent supports a value function (required for actor-critic algorithms such as PPO).

Source code in bordax/agents/base.py
class Agent(ABC):
    """Abstract base class for all agents.

    Subclasses must implement ``init`` and ``policy``. The ``action``
    method is provided as a JIT-compiled convenience wrapper around
    ``policy``. Override ``value`` if the agent supports a value function
    (required for actor-critic algorithms such as PPO).
    """

    @abstractmethod
    def init(self, key: PRNGKey, sample_obs: Any) -> AgentParameters:
        """Initialise network parameters.

        Args:
            key: JAX random key for weight initialisation.
            sample_obs: A sample observation with the correct shape
                (including the ``num_envs`` batch dimension).

        Returns:
            An ``AgentParameters`` pytree (e.g. ``PolicyValueParameters``
            or ``DQNParameters``).
        """
        ...

    @abstractmethod
    def policy(
        self, params: AgentParameters, obs: Any, key: PRNGKey
    ) -> Tuple[DistributionLike, Mapping[str, Any]]:
        """Compute the policy distribution for a batch of observations.

        Args:
            params: Current network parameters.
            obs: Batch of observations, shape ``(num_envs, *obs_shape)``.
            key: JAX random key (for stochastic policy heads).

        Returns:
            Tuple of ``(distribution, info)`` where ``distribution`` is a
            Distrax distribution and ``info`` is a dict of auxiliary data.
        """
        ...

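    @functools.partial(jax.jit, static_argnames=("self", "is_deterministic"))
    def action(
        self,
        params: AgentParameters,
        obs: Any,
        key: PRNGKey,
        is_deterministic: bool = False,
    ) -> Tuple[Any, Mapping[str, Any]]:
        """Sample an action, or take the distribution's mode if ``is_deterministic``.

        Returns:
            Tuple of ``(action, info)`` where ``info`` holds ``logp``, the
            log-probability of ``action`` (one value per environment), plus
            any auxiliary entries returned by ``policy``.
        """
        policy_key, sample_key = jax.random.split(key)
        dist, info = self.policy(params, obs, policy_key)
        if is_deterministic:
            action = dist.mode()
            logp = dist.log_prob(action)
        else:
            action, logp = dist.sample_and_log_prob(seed=sample_key)
        # Factorised policies (e.g. an elementwise Normal) give one log-prob per
        # action dimension; sum them in both branches so logp is (num_envs,).
        if isinstance(logp, jnp.ndarray) and logp.ndim > 1:
            logp = jnp.sum(logp, axis=-1)
        return action, dict(logp=logp, **info)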

    def value(self, params: Params, obs: Any) -> jnp.ndarray:
        """Compute the value estimate for a batch of observations.

        Args:
            params: Current network parameters.
            obs: Batch of observations, shape ``(num_envs, *obs_shape)``.

        Returns:
            Value estimates, shape ``(num_envs,)``.

        Raises:
            NotImplementedError: If the agent has no value function.
        """
        raise NotImplementedError
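Because Agent is an ABC with init and policy marked abstract, a subclass that leaves either one out cannot be instantiated, while value is optional and signals "no value function" by raising NotImplementedError. The sketch below checks both behaviours on a hypothetical stand-in class (AgentLike) with the same abstract surface, so it runs without bordax installed; all class and helper names are illustrative only.

```python
from abc import ABC, abstractmethod
from typing import Any, Mapping, Tuple


class AgentLike(ABC):
    """Hypothetical stand-in with the same abstract surface as Agent."""

    @abstractmethod
    def init(self, key: Any, sample_obs: Any) -> Any:
        ...

    @abstractmethod
    def policy(self, params: Any, obs: Any, key: Any) -> Tuple[Any, Mapping[str, Any]]:
        ...

    def value(self, params: Any, obs: Any) -> Any:
        # Default: the agent has no value function.
        raise NotImplementedError


class InitOnly(AgentLike):
    """Implements init but not policy, so it is still abstract."""

    def init(self, key, sample_obs):
        return {}


class PolicyAgent(InitOnly):
    """Implements both abstract methods; value keeps the default."""

    def policy(self, params, obs, key):
        return None, {}


def can_instantiate(cls: type) -> bool:
    try:
        cls()
    except TypeError:  # Python refuses to instantiate classes with abstract methods
        return False
    return True


def has_value_function(agent: AgentLike) -> bool:
    try:
        agent.value({}, None)
    except NotImplementedError:
        return False
    return True


print(can_instantiate(InitOnly), can_instantiate(PolicyAgent))  # False True
print(has_value_function(PolicyAgent()))                        # False
```

An agent meant for actor-critic training such as PPO therefore has to override value, as MLPPolicyValue does.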

action(params, obs, key, is_deterministic=False)

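Sample an action from the policy distribution, or take its mode when is_deterministic=True. Returns (action, info), where info holds the action's log-probability under logp together with any auxiliary entries returned by policy.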
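For a Categorical policy, the two branches reduce to argmax versus sampling, with the log-probability read from a log-softmax. The sketch below re-creates that logic in plain JAX instead of Distrax (the function name categorical_action is hypothetical) and jits it with is_deterministic marked static, as Agent.action does, because a Python if on that flag has to be resolved at trace time.

```python
import jax
import jax.numpy as jnp


def categorical_action(logits, key, is_deterministic=False):
    """Plain-JAX sketch of Agent.action for a Categorical policy.

    logits: (num_envs, n_actions). Returns (action, logp), each (num_envs,).
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    if is_deterministic:
        # The mode of a Categorical is its most likely action.
        action = jnp.argmax(logits, axis=-1)
    else:
        # Draw one action per environment.
        action = jax.random.categorical(key, logits, axis=-1)
    # Log-probability of the chosen action in each row.
    logp = jnp.take_along_axis(log_probs, action[:, None], axis=-1)[:, 0]
    return action, logp


# The flag selects a Python branch, so it must be static under jit.
categorical_action_jit = jax.jit(categorical_action, static_argnames=("is_deterministic",))

logits = jnp.array([[0.0, 2.0, 1.0],
                    [3.0, 0.0, 0.0]])
key = jax.random.PRNGKey(0)
greedy, greedy_logp = categorical_action_jit(logits, key, is_deterministic=True)
sampled, sampled_logp = categorical_action_jit(logits, key, is_deterministic=False)
```

In its sampling branch, Agent.action additionally sums per-dimension log-probabilities over the last axis (the case for factorised continuous policies such as a Normal head), giving one logp per environment.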


init(key, sample_obs) abstractmethod

Initialise network parameters.

Parameters:

  • key (PRNGKey, required): JAX random key for weight initialisation.
  • sample_obs (Any, required): A sample observation with the correct shape (including the num_envs batch dimension).

Returns:

  • AgentParameters: An AgentParameters pytree (e.g. PolicyValueParameters or DQNParameters).


policy(params, obs, key) abstractmethod

Compute the policy distribution for a batch of observations.

Parameters:

  • params (AgentParameters, required): Current network parameters.
  • obs (Any, required): Batch of observations, shape (num_envs, *obs_shape).
  • key (PRNGKey, required): JAX random key (for stochastic policy heads).

Returns:

  • Tuple[DistributionLike, Mapping[str, Any]]: Tuple of (distribution, info) where distribution is a Distrax distribution and info is a dict of auxiliary data.


value(params, obs)

Compute the value estimate for a batch of observations.

Parameters:

  • params (Params, required): Current network parameters.
  • obs (Any, required): Batch of observations, shape (num_envs, *obs_shape).

Returns:

  • ndarray: Value estimates, shape (num_envs,).

Raises:

  • NotImplementedError: If the agent has no value function.


BlankAgent

Bases: Agent

A trivial agent that returns a uniform categorical policy.
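Since every logit is 1, softmax assigns each of the n actions probability 1/n, so every action has log-probability -log n and the policy's entropy is log n, the maximum possible. The NumPy check below re-implements log-softmax by hand (no Distrax) for a hypothetical batch of 4 environments with 3 actions, and also shows the shape of BlankAgent's all-zero value estimate.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


num_envs, n_actions = 4, 3
logits = np.ones((num_envs, n_actions))  # the logits BlankAgent.policy builds
log_probs = log_softmax(logits)
probs = np.exp(log_probs)
entropy = -(probs * log_probs).sum(axis=-1)

# BlankAgent.value returns zeros with shape obs.shape[:-1].
obs = np.zeros((num_envs, 8))  # hypothetical flat observations
values = np.zeros(obs.shape[:-1])
```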

Source code in bordax/agents/base.py
class BlankAgent(Agent):
    """A trivial agent that returns a uniform categorical policy."""

    def __init__(self, env: EnvAdapter):
        self.action_space = env.action_space()
        if not hasattr(self.action_space, "n"):
            raise ValueError("BlankAgent only supports discrete action spaces.")
        self.batch_dim = None

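    def init(self, key: PRNGKey, sample_obs: Any) -> Params:
        # No trainable parameters; only the batch size is recorded.
        self.batch_dim = sample_obs.shape[0]
        return {}

    def policy(self, params: Params, obs: Any, key: PRNGKey) -> Tuple[Any, Mapping[str, Any]]:
        """Return a uniform categorical distribution over actions."""
        # Equal logits give every action probability 1 / n. The batch size is
        # read from obs so that batches other than the one seen by init work.
        pi = Categorical(logits=jnp.ones((obs.shape[0], self.action_space.n)))
        return pi, {}

    def value(self, params: Params, obs: Any) -> jnp.ndarray:
        """Return a zero value estimate for every observation."""
        return jnp.zeros(obs.shape[:-1])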

policy(params, obs, key)

Return a uniform categorical distribution over actions.


DQNAgent

Bases: Agent

A DQN agent with a Q-network and target network.

Source code in bordax/agents/base.py
class DQNAgent(Agent):
    """A DQN agent with a Q-network and target network."""

    def __init__(self, config: dict, env: EnvAdapter):
        self.config = config
        action_space = env.action_space()
        if not hasattr(action_space, "n"):
            raise ValueError("DQNAgent only supports discrete action spaces.")
        self.n_actions = action_space.n
        self.q_network = MLP(layer_sizes=self.config["q_layers"] + [self.n_actions])
        self.target_network = MLP(layer_sizes=self.config["q_layers"] + [self.n_actions])

    def init(self, key: PRNGKey, sample_obs: Any) -> DQNParameters:
        q_params = self.q_network.init(key, sample_obs)
        # The target network starts with the same parameters as the Q-network.
        return DQNParameters(q_network=q_params, target_network=q_params)

    @functools.partial(jax.jit, static_argnames=("self",))
    def policy(self, params: DQNParameters, obs: Any, key: PRNGKey) -> Tuple[Any, Mapping[str, Any]]:
        # Categorical over Q-values: its mode is the greedy action, and
        # sampling from it is softmax (Boltzmann) exploration.
        q_values = self.q_network.apply(params.q_network, obs)
        if isinstance(q_values, tuple):
            q_values = q_values[0]
        pi = Categorical(logits=q_values)
        return pi, {}

    @functools.partial(jax.jit, static_argnames=("self",))
    def value(self, params: DQNParameters, obs: Any) -> jnp.ndarray:
        # Greedy state value: max over actions of Q(s, a).
        q_values = self.q_network.apply(params.q_network, obs)
        if isinstance(q_values, tuple):
            q_values = q_values[0]
        return jnp.max(q_values, axis=-1)
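DQNAgent.policy wraps the Q-values in a Categorical, so Agent.action's deterministic branch picks the greedy (argmax) action, its sampling branch performs softmax (Boltzmann) exploration at temperature 1, and DQNAgent.value returns the maximum Q-value. The NumPy sketch below computes all three from a hypothetical batch of Q-values.

```python
import numpy as np

# Hypothetical Q-network output for 2 environments and 3 actions.
q_values = np.array([[1.0, 3.0, 2.0],
                     [0.5, -1.0, 0.0]])

# Mode of Categorical(logits=q_values): the greedy action (is_deterministic=True).
greedy_action = q_values.argmax(axis=-1)

# DQNAgent.value: max over actions.
state_value = q_values.max(axis=-1)

# Sampling probabilities (is_deterministic=False): softmax of the Q-values.
exp_q = np.exp(q_values - state_value[:, None])
boltzmann_probs = exp_q / exp_q.sum(axis=-1, keepdims=True)
```

Because the softmax is taken over raw Q-values, how greedy the sampled policy is depends on the scale of the Q-values.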

MLPPolicyValue

Bases: Agent

Actor-critic wrapper for discrete actions.

Source code in bordax/agents/base.py
class MLPPolicyValue(Agent):
    """Actor-critic wrapper for discrete actions."""

    def __init__(self, config: dict, env: EnvAdapter, policy_architecture: str):
        self.config = config
        action_space = env.action_space()
        if policy_architecture == "mlp":
            self.policy_module = MLP(
                layer_sizes=self.config["policy_layers"] + [action_space.n]
            )
        elif policy_architecture == "dt":
            self.policy_module = MLP_dtsemnet(
                tree_depth=self.config["tree_depth"], action_dim=action_space.n
            )
        elif policy_architecture == "bool":
            self.policy_module = MLP_boolean(
                n=self.config["n"], action_dim=action_space.n
            )
        else:
            raise ValueError(f"Unknown policy architecture: {policy_architecture}")

        self.value_module = MLP(layer_sizes=self.config["value_layers"] + [1])

    def init(self, key: PRNGKey, sample_obs: Any) -> PolicyValueParameters:
        policy_key, value_key = jax.random.split(key, 2)
        policy_params = self.policy_module.init(policy_key, sample_obs)
        value_params = self.value_module.init(value_key, sample_obs)
        return PolicyValueParameters(policy=policy_params, value=value_params)

    @functools.partial(jax.jit, static_argnames=("self",))
    def policy(self, params: PolicyValueParameters, obs: Any, key: PRNGKey) -> Tuple[Any, Mapping[str, Any]]:
        logits = self.policy_module.apply(params.policy, obs)
        if isinstance(logits, tuple):
            logits = logits[0]
        pi = Categorical(logits=logits)
        return pi, {}

    @functools.partial(jax.jit, static_argnames=("self",))
    def value(self, params: PolicyValueParameters, obs: Any) -> jnp.ndarray:
        value_out = self.value_module.apply(params.value, obs)
        if isinstance(value_out, tuple):
            value_out = value_out[0]
        # (num_envs, 1) -> (num_envs,)
        return jnp.squeeze(value_out, axis=-1)
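Each policy_architecture reads different config keys: "mlp" appends the number of actions to policy_layers, "dt" reads tree_depth, "bool" reads n, and the value head always appends a single output unit to value_layers (squeezed away in value, giving shape (num_envs,)). The hypothetical helper below mirrors that dispatch without Flax, to show which keys each choice needs; the config values are made up.

```python
from typing import Any, Dict


def head_specs(config: Dict[str, Any], n_actions: int, policy_architecture: str) -> Dict[str, Any]:
    """Mirror of MLPPolicyValue.__init__'s config handling (hypothetical helper)."""
    if policy_architecture == "mlp":
        policy = {"module": "MLP", "layer_sizes": config["policy_layers"] + [n_actions]}
    elif policy_architecture == "dt":
        policy = {"module": "MLP_dtsemnet", "tree_depth": config["tree_depth"], "action_dim": n_actions}
    elif policy_architecture == "bool":
        policy = {"module": "MLP_boolean", "n": config["n"], "action_dim": n_actions}
    else:
        raise ValueError(f"Unknown policy architecture: {policy_architecture}")
    # The value head always ends in a single unit.
    value = {"module": "MLP", "layer_sizes": config["value_layers"] + [1]}
    return {"policy": policy, "value": value}


config = {"policy_layers": [64, 64], "value_layers": [64, 64], "tree_depth": 3, "n": 4}
mlp_specs = head_specs(config, n_actions=2, policy_architecture="mlp")
dt_specs = head_specs(config, n_actions=2, policy_architecture="dt")
```

Because the layer lists are extended with +, which builds new lists, the config lists themselves are left unchanged and one config dict can be reused across agents.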

MLPPolicyValueContinuous

Bases: Agent

Actor-critic wrapper for continuous actions.

Source code in bordax/agents/base.py
class MLPPolicyValueContinuous(Agent):
    """Actor-critic wrapper for continuous actions."""

    def __init__(self, config: dict, env: EnvAdapter, policy_architecture: str):
        self.config = config
        self.n_actions = env.action_space().shape[0]
        if policy_architecture == "mlp":
            # n_actions outputs for the means plus n_actions for the scales.
            self.policy_module = MLP(
                layer_sizes=self.config["policy_layers"] + [2 * self.n_actions]
            )
        else:
            raise ValueError(f"Unknown policy architecture: {policy_architecture}")
        self.value_module = MLP(layer_sizes=self.config["value_layers"] + [1])

    def init(self, key: PRNGKey, sample_obs: Any) -> PolicyValueParameters:
        policy_key, value_key = jax.random.split(key, 2)
        policy_params = self.policy_module.init(policy_key, sample_obs)
        value_params = self.value_module.init(value_key, sample_obs)
        # policy() and value() read params.policy / params.value, which a plain
        # dict does not support, so return the same container as MLPPolicyValue.
        return PolicyValueParameters(policy=policy_params, value=value_params)

    @functools.partial(jax.jit, static_argnames=("self",))
    def policy(self, params: PolicyValueParameters, obs: Any, key: PRNGKey) -> Tuple[Any, Mapping[str, Any]]:
        distribution_parameters = self.policy_module.apply(params.policy, obs)
        if isinstance(distribution_parameters, tuple):
            distribution_parameters = distribution_parameters[0]
        # First n_actions outputs are the means; softplus keeps the scales positive.
        pi = Normal(
            loc=distribution_parameters[..., : self.n_actions],
            scale=jax.nn.softplus(distribution_parameters[..., self.n_actions :]),
        )
        return pi, {}

    @functools.partial(jax.jit, static_argnames=("self",))
    def value(self, params: PolicyValueParameters, obs: Any) -> jnp.ndarray:
        value_out = self.value_module.apply(params.value, obs)
        if isinstance(value_out, tuple):
            value_out = value_out[0]
        return jnp.squeeze(value_out, axis=-1)
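The policy network outputs 2 * n_actions numbers per observation: the first n_actions are the means, and the remaining n_actions pass through softplus to give positive scales. The NumPy sketch below reproduces that split and the diagonal-Gaussian log-probability, assuming (as Agent.action's sampling branch does) that per-dimension log-probabilities are summed over the last axis; gaussian_head and diag_normal_log_prob are hypothetical names.

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    # log(1 + exp(x)), computed stably.
    return np.logaddexp(0.0, x)


def gaussian_head(dist_params: np.ndarray, n_actions: int):
    """Split (..., 2 * n_actions) network outputs into (loc, scale)."""
    loc = dist_params[..., :n_actions]
    scale = softplus(dist_params[..., n_actions:])
    return loc, scale


def diag_normal_log_prob(action, loc, scale):
    """Sum of independent per-dimension Normal log-densities."""
    z = (action - loc) / scale
    per_dim = -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)
    return per_dim.sum(axis=-1)


n_actions = 2
dist_params = np.zeros((3, 2 * n_actions))  # hypothetical outputs for 3 environments
loc, scale = gaussian_head(dist_params, n_actions)
logp_at_mean = diag_normal_log_prob(loc, loc, scale)
```

A zero pre-activation gives a scale of log 2 (about 0.69); very negative pre-activations push the scale towards 0.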

bordax.agents.components

MLP

Bases: Module

Source code in bordax/agents/components.py
class MLP(nn.Module):
    """Simple fully-connected MLP used for policy/value heads.

    The module constructs one Dense layer per entry of `layer_sizes`.
    Every layer except the last is followed by a ReLU; the final layer's
    output is returned without an activation.
    """

    layer_sizes: List[int]

    def setup(self):
        self.dense_layers = [
            nn.Dense(size, kernel_init=nn.initializers.orthogonal())
            for size in self.layer_sizes
        ]

    def __call__(self, x):
        for layer in self.dense_layers[:-1]:
            x = layer(x)
            x = nn.relu(x)
        return self.dense_layers[-1](x)

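Simple fully-connected MLP used for policy/value heads.

The module constructs one Dense layer (with orthogonal kernel initialisation) per entry of layer_sizes, so the last entry is the output dimension. Every layer except the last is followed by a ReLU; the final layer's output is returned without an activation.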
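Stripped of Flax, the forward pass is a chain of affine maps with a ReLU between consecutive layers and none after the last. The pure-JAX sketch below uses the same orthogonal kernel initialiser (jax.nn.initializers.orthogonal) and zero biases, which is Flax's Dense default; init_mlp and mlp_apply are hypothetical names, and the parameters are a plain list of dicts rather than Flax's parameter tree.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, layer_sizes):
    """One {kernel, bias} dict per Dense layer."""
    orthogonal = jax.nn.initializers.orthogonal()
    params = []
    for size in layer_sizes:
        key, subkey = jax.random.split(key)
        params.append({
            "kernel": orthogonal(subkey, (in_dim, size), jnp.float32),
            "bias": jnp.zeros((size,), jnp.float32),
        })
        in_dim = size
    return params


def mlp_apply(params, x):
    # ReLU after every layer except the last, as in MLP.__call__.
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["kernel"] + layer["bias"])
    last = params[-1]
    return x @ last["kernel"] + last["bias"]


params = init_mlp(jax.random.PRNGKey(0), in_dim=5, layer_sizes=[32, 32, 3])
out = mlp_apply(params, jnp.ones((4, 5)))  # (num_envs=4, obs_dim=5) -> (4, 3)
```

With a 5 x 32 first kernel, the orthogonal initialiser produces orthonormal rows, so the kernel times its transpose is the 5 x 5 identity.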

MLP_boolean

Bases: Module

Source code in bordax/agents/components.py
class MLP_boolean(nn.Module):
    """Boolean-function-inspired dense module.

    A Dense layer maps the input to `n` features. Each of the 2**n rows of
    the truth table over `n` inputs, encoded as a vector in {-1, +1}^n, is
    scored by its inner product with those features. The 2**n scores are
    split into consecutive groups of `action_dim`, and each output takes the
    maximum of its position across groups, so 2**n must be divisible by
    `action_dim`.
    """

    n: int
    action_dim: int

    def setup(self):
        self.weights = nn.Dense(
            self.n,
            kernel_init=nn.initializers.orthogonal(),
            bias_init=nn.initializers.uniform(),
        )

    def __call__(self, x):

        if len(x.shape) == 1:
            x = jnp.array([x])

        x = self.weights(x)

        numbers = np.arange(2**self.n)

        binary_strings = [np.binary_repr(num, width=self.n) for num in numbers]

        function_representation = np.array(
            [[1 if char == "1" else -1 for char in binary] for binary in binary_strings]
        )
        function_representation = jnp.array(function_representation)

        x = x @ function_representation.T

        x = x.reshape((x.shape[0], -1, self.action_dim))
        x = x.max(axis=1)

        return x

Boolean-function-inspired dense module.

A Dense layer maps the input to n features. Each of the 2**n rows of the truth table over n inputs, encoded as a vector in {-1, +1}^n, is scored by its inner product with those features. The 2**n scores are split into consecutive groups of action_dim, and each output takes the maximum of its position across groups, so 2**n must be divisible by action_dim.
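For small n the construction is easy to check by hand. The NumPy sketch below builds the ±1 truth-table matrix the same way MLP_boolean does (via np.binary_repr), scores a hypothetical batch of Dense-layer features, and max-reduces the scores into action_dim outputs. The explicit divisibility check is an addition for illustration; in the module itself an incompatible action_dim surfaces as a reshape error.

```python
import numpy as np


def sign_patterns(n: int) -> np.ndarray:
    """All 2**n rows of the n-input truth table, encoded in {-1, +1}^n."""
    return np.array(
        [[1 if c == "1" else -1 for c in np.binary_repr(k, width=n)] for k in range(2**n)]
    )


def boolean_head(features: np.ndarray, n: int, action_dim: int) -> np.ndarray:
    """features: (batch, n) output of the Dense layer -> (batch, action_dim)."""
    if (2**n) % action_dim != 0:
        raise ValueError("2**n must be divisible by action_dim")
    scores = features @ sign_patterns(n).T  # (batch, 2**n)
    grouped = scores.reshape(scores.shape[0], -1, action_dim)
    return grouped.max(axis=1)  # max over groups, per action index


features = np.array([[0.5, -2.0]])  # hypothetical Dense output with n = 2
out = boolean_head(features, n=2, action_dim=2)
```

Here the four scores are 1.5, -2.5, 2.5 and -1.5; grouped in pairs and max-reduced they give [2.5, -1.5].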

MLP_dtsemnet

Bases: Module

Source code in bordax/agents/components.py
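class MLP_dtsemnet(nn.Module):
    """A decision-tree-inspired dense module.

    A Dense layer produces one activation per internal node of a complete
    binary tree of depth `tree_depth` (2**tree_depth - 1 nodes). Each of the
    2**tree_depth leaves receives a 0/1-weighted sum of the ReLU'd (+a, -a)
    node activations, where the weights for the nodes on the leaf's path
    select the branch leading to that leaf; the leaf scores are then
    max-reduced into `action_dim` outputs. It is an experimental
    architecture used as an alternative policy head.
    """

    tree_depth: int
    action_dim: int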

    def setup(self):
        self.weights = nn.Dense(
            (2 ** (self.tree_depth) - 1),
            kernel_init=nn.initializers.orthogonal(),
            bias_init=nn.initializers.uniform(),
        )

    def __call__(self, x):
        """Compute the forward pass for the tree-based representation.

        The implementation supports both single-example inputs (1D) and
        batched inputs (2D). Returns an array shaped (batch, action_dim).
        """

        if len(x.shape) == 1:
            x = jnp.array([x])

        x = self.weights(x)

        n_nodes = 2 ** (self.tree_depth) - 1
        n_leaves = n_nodes + 1

        row_indices = jnp.arange(2 * n_nodes)
        col_indices = jnp.arange(n_nodes).repeat(2)
        tiles = jnp.tile(jnp.array([1.0, -1.0]), n_nodes)
        matrix = jnp.zeros((2 * n_nodes, n_nodes), dtype=jnp.float32)
        matrix = matrix.at[row_indices, col_indices].set(tiles)

        x = nn.relu(x @ matrix.T)

        tree_representation = jnp.ones((n_leaves, 2 * n_nodes))
        for i in range(n_leaves):
            virtual_index = i + n_nodes
            relevant_indices = jnp.zeros(self.tree_depth - 1)
            replacement = jnp.ones(2 * n_nodes)
            for j in range(self.tree_depth):
                new_virtual_index = (virtual_index - 1) // 2
                relevant_indices = relevant_indices.at[self.tree_depth - j].set(
                    new_virtual_index
                )
                if virtual_index % 2 == 0:
                    replacement_tile = jnp.array([0, 1])
                else:
                    replacement_tile = jnp.array([1, 0])
                virtual_index = new_virtual_index
                replacement = replacement.at[
                    2 * virtual_index : 2 * virtual_index + 2
                ].set(replacement_tile)
            tree_representation = tree_representation.at[i].set(replacement)

        appendice = jnp.zeros(
            ((self.action_dim - (n_leaves % self.action_dim)), 2 * n_nodes)
        )
        tree_representation = jnp.concatenate((tree_representation, appendice), axis=0)

        x = x @ tree_representation.T

        x = x.reshape((x.shape[0], -1, self.action_dim))
        x = x.max(axis=1)

        return x

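A decision-tree-inspired dense module.

A Dense layer produces one activation per internal node of a complete binary tree of depth tree_depth (2**tree_depth - 1 nodes). Each of the 2**tree_depth leaves receives a 0/1-weighted sum of the ReLU'd (+a, -a) node activations, where the weights for the nodes on the leaf's path select the branch leading to that leaf; the leaf scores are then max-reduced into action_dim outputs. It is an experimental architecture used as an alternative policy head.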

__call__(x)

Compute the forward pass for the tree-based representation.

The implementation supports both single-example inputs (1D) and batched inputs (2D). Returns an array shaped (batch, action_dim).
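The NumPy sketch below re-traces that forward pass for a small tree. It builds the (2 * n_nodes, n_nodes) matrix that splits each node activation a into the pair (+a, -a), and the per-leaf 0/1 selection matrix obtained by walking from each leaf (heap index i + n_nodes) up to the root, following the loop in __call__; it then pads, reshapes and max-reduces as the module does. Function names and inputs are hypothetical. Because the leaf scores are sums of ReLU outputs and therefore non-negative, the zero padding rows can never change the max.

```python
import numpy as np


def node_sign_matrix(n_nodes: int) -> np.ndarray:
    """Row 2i is +e_i and row 2i+1 is -e_i, so a @ M.T = [a_0, -a_0, a_1, -a_1, ...]."""
    m = np.zeros((2 * n_nodes, n_nodes))
    for i in range(n_nodes):
        m[2 * i, i] = 1.0
        m[2 * i + 1, i] = -1.0
    return m


def leaf_matrix(tree_depth: int) -> np.ndarray:
    """(n_leaves, 2 * n_nodes) selection matrix built as in MLP_dtsemnet.__call__."""
    n_nodes = 2**tree_depth - 1
    n_leaves = n_nodes + 1
    rep = np.ones((n_leaves, 2 * n_nodes))
    for leaf in range(n_leaves):
        v = leaf + n_nodes  # heap index of the leaf
        row = np.ones(2 * n_nodes)  # entries off the root-to-leaf path stay 1
        for _ in range(tree_depth):
            parent = (v - 1) // 2
            # Odd heap index = left child -> keep relu(+a); even = right child -> relu(-a).
            row[2 * parent : 2 * parent + 2] = [1, 0] if v % 2 == 1 else [0, 1]
            v = parent
        rep[leaf] = row
    return rep


def dtsemnet_scores(node_acts: np.ndarray, tree_depth: int, action_dim: int) -> np.ndarray:
    """node_acts: (batch, n_nodes) output of the Dense layer -> (batch, action_dim)."""
    n_nodes = node_acts.shape[-1]
    halves = np.maximum(node_acts @ node_sign_matrix(n_nodes).T, 0.0)
    rep = leaf_matrix(tree_depth)
    # Same padding rule as the module: action_dim - (n_leaves % action_dim) zero rows.
    pad = action_dim - (rep.shape[0] % action_dim)
    rep = np.concatenate([rep, np.zeros((pad, rep.shape[1]))], axis=0)
    leaf_scores = halves @ rep.T
    return leaf_scores.reshape(leaf_scores.shape[0], -1, action_dim).max(axis=1)


out = dtsemnet_scores(np.array([[0.5, -1.0, 2.0]]), tree_depth=2, action_dim=3)
```

For these inputs the four leaf scores are 2.5, 3.5, 3.0 and 1.0; two zero rows pad them to six, and the max over the two groups of three gives [2.5, 3.5, 3.0].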


bordax.agents.utils

make_agent(agent_name, env, agent_config={})

Create an agent by name.

Parameters:

  • agent_name (str, required): Registry key of the form "<agent>/<policy_architecture>". For the actor-critic agents, the part after the slash selects the policy architecture (it is passed to the constructor as policy_architecture). Supported values:
      • "mlp/mlp": MLP policy with MLP value (discrete actions)
      • "mlp/dt": DTSemNet decision-tree policy with MLP value
      • "mlp/bool": HyperBool boolean policy with MLP value
      • "mlp-continuous/mlp": MLP policy with MLP value (continuous actions)
      • "dqn/mlp": DQN Q-network agent
      • "blank/blank": uniform random agent (baseline)
  • env (EnvAdapter, required): Environment adapter used to infer observation and action spaces.
  • agent_config (dict, default {}): Dict of hyperparameters passed to the agent constructor. Required keys depend on the agent type (e.g. policy_layers and value_layers for "mlp/mlp"; tree_depth or n, plus value_layers, for "mlp/dt" and "mlp/bool"; q_layers for DQN).

Returns:

  • Agent: A constructed Agent instance. Its parameters are created separately by calling agent.init(key, sample_obs).

Raises:

  • ValueError: If agent_name is not in the registry.

Source code in bordax/agents/utils.py
def make_agent(agent_name: str, env: EnvAdapter, agent_config: dict = {}) -> Agent:
    """Create an agent by name.

    Args:
        agent_name: Registry key of the form ``"<agent>/<policy_architecture>"``.
            For the actor-critic agents, the part after the slash selects the
            policy architecture. Supported values:

            - ``"mlp/mlp"``: MLP policy with MLP value (discrete actions)
            - ``"mlp/dt"``: DTSemNet decision-tree policy with MLP value
            - ``"mlp/bool"``: HyperBool boolean policy with MLP value
            - ``"mlp-continuous/mlp"``: MLP policy with MLP value (continuous actions)
            - ``"dqn/mlp"``: DQN Q-network agent
            - ``"blank/blank"``: uniform random agent (baseline)

        env: Environment adapter used to infer observation and action spaces.
        agent_config: Dict of hyperparameters passed to the agent constructor.
            Required keys depend on the agent type (e.g. ``policy_layers``,
            ``value_layers`` for MLP agents; ``q_layers`` for DQN).

    Returns:
        A constructed ``Agent`` instance; call ``agent.init(key, sample_obs)``
        to create its parameters.

    Raises:
        ValueError: If ``agent_name`` is not in the registry.
    """
    try:
        cls = AGENT_REGISTRY[agent_name]
    except KeyError:
        raise ValueError(
            f"Agent {agent_name} is not supported. "
            f"Supported agents are: {list(AGENT_REGISTRY.keys())}"
        ) from None

    # DQNAgent doesn't need the policy_architecture parameter.
    if agent_name.startswith("dqn/"):
        return cls(agent_config, env)

    # BlankAgent.__init__ takes only the environment.
    if agent_name.startswith("blank/"):
        return cls(env)

    return cls(agent_config, env, agent_name.split("/")[1])
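The dispatch rule can be exercised without bordax. In the sketch below, the registry contents and the stub classes are hypothetical (the real AGENT_REGISTRY is not shown on this page); the function keeps make_agent's rules: unknown names raise ValueError, "dqn/..." keys are constructed without a policy architecture, and otherwise the part after the slash becomes policy_architecture. It also uses None instead of {} as the default config, the usual way to avoid sharing one mutable default dict between calls.

```python
from typing import Any, Dict, Optional


class StubActorCritic:
    def __init__(self, config: dict, env: Any, policy_architecture: str):
        self.config = config
        self.policy_architecture = policy_architecture


class StubDQN:
    def __init__(self, config: dict, env: Any):
        self.config = config


STUB_REGISTRY: Dict[str, type] = {
    "mlp/mlp": StubActorCritic,
    "mlp/dt": StubActorCritic,
    "mlp/bool": StubActorCritic,
    "dqn/mlp": StubDQN,
}


def make_agent_sketch(agent_name: str, env: Any, agent_config: Optional[dict] = None):
    # A fresh dict per call instead of a shared mutable default.
    config = {} if agent_config is None else agent_config
    try:
        cls = STUB_REGISTRY[agent_name]
    except KeyError:
        raise ValueError(
            f"Agent {agent_name} is not supported. Supported agents are: {list(STUB_REGISTRY)}"
        ) from None
    if agent_name.startswith("dqn/"):
        return cls(config, env)
    return cls(config, env, agent_name.split("/")[1])


dt_agent = make_agent_sketch("mlp/dt", env=None)  # policy_architecture == "dt"
```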