Shortcuts

循环 DQN:训练循环策略

Created On: Nov 08, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Not Verified

作者Vincent Moens

What you will learn
  • 如何在 TorchRL 中将 RNN 集成到一个 actor 中

  • 如何将基于记忆的策略与回放缓冲区和损失模块结合使用

Prerequisites
  • PyTorch v2.0.0

  • gym[mujoco]

  • tqdm

概述

基于记忆的策略不仅在观察部分可观察时至关重要,而且在需要考虑时间维度以做出明智决策时也是如此。

循环神经网络长期以来一直是基于记忆的策略的热门工具。其核心思想是保留一个循环状态,在两个连续步骤之间存储在内存中,并将其作为策略的输入,与当前观察一起使用。

本教程展示了如何使用 TorchRL 将 RNN 集成到策略中。

关键学习点:

  • 在 TorchRL 中将 RNN 集成到一个 actor 中;

  • 将基于记忆的策略与回放缓冲区和损失模块结合使用。

在 TorchRL 中使用 RNN 的核心思想是使用 TensorDict 作为隐藏状态在步骤之间的载体。我们将构建一个策略,该策略从当前 TensorDict 读取先前的循环状态,并将当前循环状态写入下一状态的 TensorDict 中:

使用循环策略进行数据收集

如图所示,我们的环境用清零的循环状态填充 TensorDict,这些循环状态由策略与观察一起读取以生成操作,以及将用于下一步的循环状态。当调用 step_mdp() 函数时,下一状态的循环状态将被带入到当前 TensorDict 中。让我们看看如何在实际中实现这一点。

如果您在 Google Colab 中运行此程序,请确保安装以下依赖项:

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

设置

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

环境

通常,第一步是构建我们的环境:它帮助我们定义问题,并相应地构建策略网络。在本教程中,我们将运行一个基于像素的 CartPole gym 环境实例,并添加一些自定义变换:转换为灰度,调整大小为 84x84,缩减奖励并规范化观察。

备注

StepCounter 转换是附加的。由于 CartPole 任务的目标是使轨迹尽可能长,计数步骤可以帮助我们跟踪策略的性能。

本教程中有两个重要的变换:

  • InitTracker 将通过在 TensorDict 中添加一个 "is_init" 布尔掩码标记对 reset() 的调用,这个掩码将跟踪哪些步骤需要重置 RNN 隐藏状态。

  • TensorDictPrimer 转换更为技术性。它不是使用 RNN 策略的必需条件。但是,它指示环境(以及随后是采集器)预期某些附加的关键值。一旦添加,调用 env.reset() 会用清零的张量填充 primer 中指示的条目。因为这些张量是策略所需的,采集器将在采集过程中传递它们。最终,我们将在回放缓冲区中存储我们的隐藏状态,这将有助于在损失模块中对 RNN 操作进行引导计算(否则将以 0s 开始)。总之:不包括此转换对我们的策略训练影响不大,但会使循环键从收集的数据和回放缓冲区中消失,从而导致训练效果略有下降。幸运的是,我们提供的 LSTMModule 配备了一个帮助方法,可以为我们构建这一变换,因此我们可以等到构建它为止!

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

像往常一样,我们需要手动初始化我们的归一化常数:

env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()

策略

我们的策略将有三个组成部分:一个 ConvNet 主干,一个 LSTMModule 记忆层以及一个浅层 MLP 块,将 LSTM 输出映射到动作值上。

卷积网络

我们构建一个卷积网络,并在其两侧使用 torch.nn.AdaptiveAvgPool2d,它会将输出压缩为大小为 64 的向量。ConvNet 可以协助我们完成这个任务:

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)

我们在一批数据上执行第一个模块以获取输出向量的大小:

n_cells = feature(env.reset())["embed"].shape[-1]

LSTM 模块

TorchRL 提供了一个专用的 LSTMModule 类,用于将 LSTM 整合到您的代码库中。它是 TensorDictModuleBase 的子类:因此,它有一组 in_keysout_keys,指示在模块执行期间应读取和写入/更新哪些值。该类为这些属性提供了可自定义的预定义值,以便于其构建。

备注

使用限制:该类支持几乎所有 LSTM 的功能,例如 dropout 或多层 LSTM。然而,为了符合 TorchRL 的惯例,这个 LSTM 必须将 batch_first 属性设置为 True,这在 PyTorch 中并非默认值。不过,我们的 LSTMModule 改变了这种默认行为,所以可以直接使用。

此外,LSTM不能将``bidirectional``属性设置为``True``,因为这在在线设置中将不可用。在这种情况下,默认值是正确的。

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

让我们看看LSTM模块类,特别是它的in_keys和out_keys:

print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)
in_keys ['embed', 'recurrent_state_h', 'recurrent_state_c', 'is_init']
out_keys ['embed', ('next', 'recurrent_state_h'), ('next', 'recurrent_state_c')]

我们可以看到这些值包含了我们指定为in_key(和out_key)的键以及循环键名称。out_keys前缀为“next”,表明它们需要写入“next”TensorDict中。我们采用这种约定(可以通过传递in_keys/out_keys参数覆盖),以确保调用:func:`~torchrl.envs.utils.step_mdp`时,会将循环状态移动到根TensorDict,从而在后续调用中可供RNN使用(见介绍中的图)。

如前所述,为了确保循环状态传递到缓冲区,我们还有一个可选的转化需要添加到我们的环境中。:meth:`~torchrl.modules.LSTMModule.make_tensordict_primer`方法正是这样做的:

env.append_transform(lstm.make_tensordict_primer())
TransformedEnv(
    env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            ToTensorImage(keys=['pixels']),
            GrayScale(keys=['pixels']),
            Resize(w=84, h=84, interpolation=InterpolationMode.BILINEAR, keys=['pixels']),
            StepCounter(keys=[]),
            InitTracker(keys=[]),
            RewardScaling(loc=0.0000, scale=0.1000, keys=['reward']),
            ObservationNorm(keys=['pixels']),
            TensorDictPrimer(primers=Composite(
                recurrent_state_h: UnboundedContinuous(
                    shape=torch.Size([1, 128]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous),
                recurrent_state_c: UnboundedContinuous(
                    shape=torch.Size([1, 128]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous),
                device=cpu,
                shape=torch.Size([])), default_value={'recurrent_state_h': 0.0, 'recurrent_state_c': 0.0}, random=None)))

这样就完成了!我们可以打印环境以检查添加primer后的一切是否正常:

print(env)
TransformedEnv(
    env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            ToTensorImage(keys=['pixels']),
            GrayScale(keys=['pixels']),
            Resize(w=84, h=84, interpolation=InterpolationMode.BILINEAR, keys=['pixels']),
            StepCounter(keys=[]),
            InitTracker(keys=[]),
            RewardScaling(loc=0.0000, scale=0.1000, keys=['reward']),
            ObservationNorm(keys=['pixels']),
            TensorDictPrimer(primers=Composite(
                recurrent_state_h: UnboundedContinuous(
                    shape=torch.Size([1, 128]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous),
                recurrent_state_c: UnboundedContinuous(
                    shape=torch.Size([1, 128]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous),
                device=cpu,
                shape=torch.Size([])), default_value={'recurrent_state_h': 0.0, 'recurrent_state_c': 0.0}, random=None)))

MLP

我们使用一个单层MLP来表示我们用于策略的动作值。

mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)

并用零填充偏置:

mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

使用Q值选择动作

我们策略的最后一部分是Q值模块。Q值模块:class:~torchrl.modules.tensordict_module.QValueModule`将读取由我们的MLP生成的`”action_values”``键,并从中选择值最大的动作。我们只需指定动作空间即可,这可以通过传递字符串或动作规格来完成。这样可以使用Categorical(有时称为”稀疏”)编码或其独热形式的版本。

qval = QValueModule(spec=env.action_spec)

备注

TorchRL还提供了一个包装类:class:torchrl.modules.QValueActor,它将一个模块与:class:`~torchrl.modules.tensordict_module.QValueModule`模块一起以Sequential的形式封装。这样做优势不大且过程不够透明,但结果与我们这里的实现相似。

我们现在可以使用:class:`~tensordict.nn.TensorDictSequential`将这些组件组合起来。

stoch_policy = Seq(feature, lstm, mlp, qval)

由于DQN是一个确定性算法,探索是其关键部分。我们将使用一个epsilon-贪婪策略,其中epsilon从0.2逐步递减到0。这种衰减是通过调用:meth:`~torchrl.modules.EGreedyModule.step`实现的(见下方训练循环)。

exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)

使用模型进行损失计算

我们构建的模型在顺序设置中可以很好地工作。然而,类:class:`torch.nn.LSTM`可以使用cuDNN优化后的后端在GPU设备上更快地运行RNN序列。我们绝不希望错过这样一个加速训练循环的机会!为了使用它,我们只需要告诉LSTM模块在通过损失使用时以“循环模式”运行。由于我们通常需要两个LSTM模块实例,我们通过调用:meth:`~torchrl.modules.LSTMModule.set_recurrent_mode`方法,该方法将返回一个新的LSTM实例(共享权重),并假设输入数据是顺序性质的。

policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torchrl/modules/tensordict_module/rnn.py:710: DeprecationWarning:

The lstm.set_recurrent_mode() API is deprecated and will be removed in v0.8. To set the recurent mode, use the :class:`~torchrl.modules.set_recurrent_mode` context manager or the `default_recurrent_mode` keyword argument in the constructor.

因为我们仍然有一些未初始化的参数,所以在创建优化器等之前应该先初始化它们。

policy(env.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        embed: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                recurrent_state_c: Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
                recurrent_state_h: Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_c: Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_h: Tensor(shape=torch.Size([1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

DQN损失

我们的DQN损失需要我们传递策略和动作空间。虽然这看起来有些冗余,但很重要,因为我们必须确保:class:`~torchrl.objectives.DQNLoss`和:class:`~torchrl.modules.tensordict_module.QValueModule`类是兼容的,但不是强依赖。

为了使用双DQN,我们提供``delay_value``参数,该参数将创建网络参数的非可微副本,用作目标网络。

loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)

由于我们使用双DQN,因此需要更新目标参数。我们将使用:class:`~torchrl.objectives.SoftUpdate`实例来完成这项工作。

updater = SoftUpdate(loss_fn, eps=0.95)

optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

收集器和回放缓冲区

我们构建了最简单的数据收集器。我们尝试用一百万帧训练算法,每次扩展缓冲区50帧。该缓冲区被设计为存储20000条50步的轨迹。在每次优化步骤(每次数据收集16次)中,我们从缓冲区中收集4项,总计200次转换。我们将使用:class:`~torchrl.data.replay_buffers.LazyMemmapStorage`存储器在磁盘上保存数据。

备注

为了提高效率,这里我们只运行了一几千次迭代。在实际设置中,帧总数应设置为1M。

collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)

训练循环

为了跟踪进展,我们每进行50次数据收集就让策略运行一次,并在训练后绘制结果。

utd = 16
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_values, action and chosen_action_value),"
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())
    updater.step()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())
  0%|          | 0/1000000 [00:00<?, ?it/s]Let us print the first batch of data.
Pay attention to the key names which will reflect what can be found in this data structure, in particular: the output of the QValueModule (action_values, action and chosen_action_value),the 'is_init' key that will tell us if a step is initial or not, and the recurrent_state keys.
 TensorDict(
    fields={
        action: Tensor(shape=torch.Size([50, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([50, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([50]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([50]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        embed: Tensor(shape=torch.Size([50, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        is_init: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([50, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
                recurrent_state_c: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
                recurrent_state_h: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([50]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([50, 1, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_c: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        recurrent_state_h: Tensor(shape=torch.Size([50, 1, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([50, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([50]),
    device=cpu,
    is_shared=False)

  0%|          | 50/1000000 [00:00<3:27:33, 80.30it/s]
  0%|          | 50/1000000 [00:20<3:27:33, 80.30it/s]
steps: 13, loss_val:  0.0002, action_spread: tensor([ 8, 42]):   0%|          | 50/1000000 [00:37<3:27:33, 80.30it/s]
steps: 13, loss_val:  0.0002, action_spread: tensor([ 8, 42]):   0%|          | 100/1000000 [00:38<125:14:11,  2.22it/s]
steps: 27, loss_val:  0.0002, action_spread: tensor([36, 14]):   0%|          | 100/1000000 [01:17<125:14:11,  2.22it/s]
steps: 27, loss_val:  0.0002, action_spread: tensor([36, 14]):   0%|          | 150/1000000 [01:18<169:50:23,  1.64it/s]
steps: 27, loss_val:  0.0003, action_spread: tensor([ 7, 43]):   0%|          | 150/1000000 [01:57<169:50:23,  1.64it/s]
steps: 27, loss_val:  0.0003, action_spread: tensor([ 7, 43]):   0%|          | 200/1000000 [01:58<190:36:42,  1.46it/s]
steps: 27, loss_val:  0.0003, action_spread: tensor([40, 10]):   0%|          | 200/1000000 [02:37<190:36:42,  1.46it/s]

让我们绘制结果:

if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")
Test trajectory lengths

结论

我们已经看到如何在TorchRL中将RNN融合到策略中。现在您应该能够:

  • 创建一个充当:class:`~tensordict.nn.TensorDictModule`的LSTM模块。

  • 通过:class:`~torchrl.envs.transforms.InitTracker`转化向LSTM模块指示需要重置。

  • 将此模块集成到策略中以及损失模块中。

  • 确保收集器知道循环状态条目,使它们能够与数据的其余部分一起存储在回放缓冲区中。

进一步阅读

Total running time of the script: ( 2 minutes 43.613 seconds)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源