PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

使用 TorchRL 进行强化学习 (PPO) 教程

作者: Vincent Moens

本教程将演示如何使用 PyTorch 和 torchrl 来训练一个参数化策略网络,以解决来自 OpenAI-Gym/Farama-Gymnasium 控制库 的倒立摆任务。

倒立摆

倒立摆

主要学习内容:

  • 如何在 TorchRL 中创建环境,转换其输出,并从该环境中收集数据;

  • 如何使用 TensorDict 让您的类之间进行通信

  • 使用 TorchRL 构建训练循环的基础知识:

    • 如何计算策略梯度方法中的优势信号;

    • 如何使用概率神经网络创建随机策略;

    • 如何创建动态经验回放缓冲区并从中无重复地采样。

我们将涵盖 TorchRL 的六个关键组件:

如果您在 Google Colab 中运行此代码,请确保安装以下依赖项:

!pip3installtorchrl
!pip3installgym[mujoco]
!pip3installtqdm

近端策略优化(Proximal Policy Optimization,PPO)是一种策略梯度算法,它通过收集一批数据并直接用于训练策略,以在给定的近端性约束下最大化预期回报。您可以将其视为REINFORCE的复杂版本,后者是基础的策略优化算法。有关更多信息,请参阅近端策略优化算法论文。

PPO通常被认为是一种快速高效的在线策略强化算法。TorchRL提供了一个损失模块,它会为您完成所有工作,因此您可以依赖此实现,专注于解决问题,而不是每次训练策略时都重新发明轮子。

为了完整性,这里简要概述了损失计算的内容,尽管这些已由我们的ClipPPOLoss模块处理——算法的工作原理如下:1. 我们通过在环境中执行策略来采样一批数据。2. 然后,我们将使用裁剪版的REINFORCE损失,对该批数据的随机子样本进行一定次数的优化步骤。3. 裁剪将对我们的损失设置一个悲观的上限:较低的回报估计值将比更高的回报估计值更受青睐。损失的具体公式如下:

\[L(s,a,\theta_k,\theta) = \min\left( \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)} A^{\pi_{\theta_k}}(s,a), \;\; g(\epsilon, A^{\pi_{\theta_k}}(s,a)) \right),\]

损失函数中有两个组成部分:在最小值算子的第一部分,我们简单地计算了一个基于重要性加权的 REINFORCE 损失(例如,一个 REINFORCE 损失,我们对其进行了修正,以考虑当前策略配置与用于数据收集的策略配置之间的差异)。最小值算子的第二部分是一个类似的损失,但我们在比率超过或低于给定阈值对时进行了裁剪。

这种损失确保无论优势是正还是负,都会抑制那些可能导致与先前配置产生显著变化的策略更新。

本教程的结构如下:

  1. 首先,我们将定义一组用于训练的超参数。

  2. 接下来,我们将专注于使用 TorchRL 的封装器和转换器来创建我们的环境或模拟器。

  3. 然后,我们将设计策略网络和价值模型,这些模型对于损失函数至关重要。这些模块将用于配置我们的损失模块。

  4. 接着,我们将创建回放缓冲区和数据加载器。

  5. 最后,我们将运行训练循环并分析结果。

在本教程中,我们将使用 tensordict 库。TensorDict 是 TorchRL 的通用语言:它帮助我们抽象模块的读写操作,让我们更少关注具体的数据描述,而更多地关注算法本身。

importwarnings
warnings.filterwarnings("ignore")
fromtorchimport multiprocessing


fromcollectionsimport defaultdict

importmatplotlib.pyplotasplt
importtorch
fromtensordict.nnimport TensorDictModule
fromtensordict.nn.distributionsimport NormalParamExtractor
fromtorchimport nn
fromtorchrl.collectorsimport SyncDataCollector
fromtorchrl.data.replay_buffersimport ReplayBuffer
fromtorchrl.data.replay_buffers.samplersimport SamplerWithoutReplacement
fromtorchrl.data.replay_buffers.storagesimport LazyTensorStorage
fromtorchrl.envsimport (Compose, DoubleToFloat, ObservationNorm, StepCounter,
                          TransformedEnv)
fromtorchrl.envs.libs.gymimport GymEnv
fromtorchrl.envs.utilsimport check_env_specs, ExplorationType, set_exploration_type
fromtorchrl.modulesimport ProbabilisticActor, TanhNormal, ValueOperator
fromtorchrl.objectivesimport ClipPPOLoss
fromtorchrl.objectives.valueimport GAE
fromtqdmimport tqdm

定义超参数

我们为算法设置了超参数。根据可用资源,可以选择在 GPU 或其他设备上执行策略。frame_skip 将控制单个动作在多少帧内执行。其余涉及帧计数的参数必须针对此值进行调整(因为一个环境步骤实际上会返回 frame_skip 帧)。

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

数据收集参数

在收集数据时,我们可以通过定义 frames_per_batch 参数来选择每个批次的大小。我们还将定义允许使用的帧数(例如与模拟器的交互次数)。一般来说,强化学习算法的目标是尽可能以最少的环境交互次数来完成任务:total_frames 越小越好。

frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

PPO 参数

在每次数据收集(或批量收集)时,我们将在一定数量的 epochs 上运行优化,每次都会在嵌套的训练循环中消耗我们刚刚获取的所有数据。在这里,sub_batch_size 与上面提到的 frames_per_batch 不同:请记住,我们正在处理来自收集器的“数据批次”,其大小由 frames_per_batch 定义,并且我们将在内部训练循环中进一步将其拆分为更小的子批次。这些子批次的大小由 sub_batch_size 控制。

sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

定义环境

在强化学习(RL)中,环境通常指的是模拟器或控制系统。许多库提供了强化学习的模拟环境,包括 Gymnasium(原 OpenAI Gym)、DeepMind 控制套件等。作为一个通用库,TorchRL 的目标是为各种 RL 模拟器提供一个可互换的接口,使您能够轻松地替换不同的环境。例如,只需几行代码就可以创建一个封装的 gym 环境:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

在这段代码中有几点需要注意:首先,我们通过调用 GymEnv 包装器创建了环境。如果传递了额外的关键字参数,它们将被传递给 gym.make 方法,从而涵盖了最常见的环境构建命令。或者,也可以直接使用 gym.make(env_name, **kwargs) 创建一个 gym 环境,并将其包装在 GymWrapper 类中。

另外需要注意的是 device 参数:对于 gym 来说,它仅控制输入动作和观察状态的存储设备,但执行始终会在 CPU 上进行。原因很简单,除非另有指定,否则 gym 不支持在设备上执行。对于其他库,我们可以控制执行设备,并且尽可能在存储和执行后端方面保持一致。

转换

我们将向环境中添加一些转换,以准备策略所需的数据。在 Gym 中,这通常通过包装器来实现。而 TorchRL 采用了不同的方法,更类似于其他 PyTorch 领域库,通过使用转换来实现。要为环境添加转换,只需将其包装在 TransformedEnv 实例中,并将转换序列附加到该实例上。转换后的环境将继承被包装环境的设备和元数据,并根据所包含的转换序列对这些数据进行转换。

归一化

首先需要编码的是一个标准化转换。通常来说,最好是让数据大致符合单位高斯分布:为了实现这一点,我们将在环境中运行一定数量的随机步骤,并计算这些观测值的汇总统计量。

我们还会附加两个其他转换:DoubleToFloat 转换将把双精度条目转换为单精度数字,以便策略读取。StepCounter 转换将用于在环境终止之前对步骤进行计数。我们将使用这一指标作为性能的补充衡量标准。

正如我们稍后将看到的,TorchRL 的许多类都依赖于 TensorDict 来进行通信。你可以将其视为具有一些额外张量功能的 Python 字典。实际上,这意味着我们将使用的许多模块需要被告知在接收到的 tensordict 中读取哪个键(in_keys)和写入哪个键(out_keys)。通常,如果省略了 out_keys,则假设 in_keys 条目将就地更新。对于我们的转换,我们唯一感兴趣的条目被称为 "observation",我们的转换层将被指示仅修改这一条目:

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

正如您可能已经注意到的,我们创建了一个归一化层,但尚未设置其归一化参数。为此,ObservationNorm 可以自动收集我们环境的汇总统计信息:

env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

ObservationNorm 转换现在已经被填充了位置和比例参数,这些参数将用于对数据进行标准化。

让我们对汇总统计的形状做一个简单的检查:

print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])

环境不仅由其模拟器和转换定义,还由一系列元数据定义,这些元数据描述了在执行过程中可以预期的内容。出于效率考虑,TorchRL 在环境规范方面非常严格,但您可以轻松检查您的环境规范是否合适。在我们的示例中,继承自 GymWrapperGymEnv 已经负责为您的环境设置适当的规范,因此您无需关心这一点。

尽管如此,我们通过查看其规范来使用转换后的环境,看看一个具体的示例。有三个规范需要关注:observation_spec 定义了在环境中执行操作时可以预期的内容,reward_spec 指示奖励的范围,最后是 input_spec(包含 action_spec),它表示环境执行单个步骤所需的一切。

print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: Composite(
    observation: UnboundedContinuous(
        shape=torch.Size([11]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedDiscrete(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=discrete),
    device=cpu,
    shape=torch.Size([]))
reward_spec: UnboundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)
input_spec: Composite(
    full_state_spec: Composite(
        step_count: BoundedDiscrete(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
            device=cpu,
            dtype=torch.int64,
            domain=discrete),
        device=cpu,
        shape=torch.Size([])),
    full_action_spec: Composite(
        action: BoundedContinuous(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        device=cpu,
        shape=torch.Size([])),
    device=cpu,
    shape=torch.Size([]))
action_spec (as defined by input_spec): BoundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

check_env_specs() 函数会执行一个小规模的 rollout,并将其输出与环境规格进行比较。如果没有抛出错误,我们可以确信这些规格已正确定义:

check_env_specs(env)

为了有趣起见,让我们看看一个简单的随机rollout是什么样的。您可以调用env.rollout(n_steps),并查看环境输入和输出的概览。动作将自动从动作规范域中抽取,因此您无需关心设计随机采样器。

通常,在每一步中,强化学习环境接收一个动作作为输入,并输出一个观测值、一个奖励和一个完成状态。观测值可能是复合的,这意味着它可能由多个张量组成。这对TorchRL来说不是问题,因为所有的观测值会自动打包到输出的TensorDict中。在给定步数内执行一个rollout(例如,一系列环境步骤和随机动作生成)后,我们将检索到一个形状与该轨迹长度匹配的TensorDict实例:

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])

我们的 rollout 数据具有 torch.Size([3]) 的形状,这与我们运行的步数相匹配。"next" 条目指向当前步骤之后的数据。在大多数情况下,时间 t 的 "next" 数据与 t+1 时的数据相匹配,但如果我们使用一些特定的转换(例如,多步转换),则可能并非如此。

策略

PPO 采用随机策略来处理探索问题。这意味着我们的神经网络需要输出分布参数,而不是与所采取动作对应的单一值。

由于数据是连续的,我们使用 Tanh-Normal 分布来确保动作空间的边界得到遵守。TorchRL 提供了这种分布,我们唯一需要关心的是构建一个神经网络,使其输出适当的参数数量以供策略使用(一个位置或均值,以及一个尺度):

\[f_{\theta}(\text{观测}) = \mu_{\theta}(\text{观测}), \sigma^{+}_{\theta}(\text{观测})\]

这里唯一额外的难点是将输出分成两个相等的部分,并将第二部分映射到严格为正的空间。

我们分三步设计策略:

  1. 定义一个神经网络 D_obs -> 2 * D_action。实际上,我们的 loc (mu) 和 scale (sigma) 都具有维度 D_action

  2. 添加一个 NormalParamExtractor 来提取位置和尺度参数(例如,将输入分成两个相等的部分,并对尺度参数应用一个正变换)。

  3. 创建一个可以生成该分布并从中采样的概率性 TensorDictModule

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)

为了使策略能够通过tensordict数据载体与环境“对话”,我们将nn.Module包装在TensorDictModule中。该类将简单地读取提供的in_keys,并将输出写入注册的out_keys中。

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

现在我们需要根据正态分布的位置和尺度构建一个分布。为此,我们指示 ProbabilisticActor 类根据位置和尺度参数构建一个 TanhNormal。我们还提供了该分布的最小值和最大值,这些值是从环境规格中获取的。

in_keys 的名称(因此也是上面 TensorDictModuleout_keys 的名称)不能随意设置,因为 TanhNormal 分布构造函数会期望 locscale 关键字参数。也就是说,ProbabilisticActor 也接受 Dict[str, str] 类型的 in_keys,其中键值对指示每个要使用的关键字参数应使用哪个 in_key 字符串。

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec.space.low,
        "high": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

价值网络

价值网络是 PPO 算法中的关键组成部分,尽管在推理时不会使用它。该模块将读取观察值并返回对后续轨迹的折现回报的估计。这使得我们能够通过在训练过程中动态学习的效用估计来分摊学习成本。我们的价值网络与策略网络共享相同的结构,但为了简单起见,我们为其分配了独立的参数集。

value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

让我们来尝试使用我们的策略和价值模块。正如我们之前提到的,使用 TensorDictModule 可以直接读取环境的输出来运行这些模块,因为它们知道要读取哪些信息以及在哪里写入这些信息:

print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
Running value: TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

数据收集器

TorchRL 提供了一组 DataCollector 类。简而言之,这些类执行三个操作:重置环境、根据最新观察计算动作、在环境中执行一步,然后重复最后两个步骤,直到环境发出停止信号(或达到完成状态)。

它们允许您控制每次迭代要收集多少帧数据(通过 frames_per_batch 参数)、何时重置环境(通过 max_frames_per_traj 参数)、策略应在哪个 device 上执行等等。它们还设计为能够高效地与批处理和多进程环境一起工作。

最简单的数据收集器是 SyncDataCollector:它是一个迭代器,您可以使用它来获取给定长度的数据批次,并且一旦收集到指定数量的帧(total_frames),它就会停止。其他数据收集器(MultiSyncDataCollectorMultiaSyncDataCollector)将以同步和异步方式在一组多进程工作器上执行相同的操作。

与之前的策略和环境一样,数据收集器将返回 TensorDict 实例,其元素总数将与 frames_per_batch 匹配。使用 TensorDict 将数据传递给训练循环,可以让您编写的数据加载管道完全忽略 rollout 内容的具体细节。

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

回放缓冲区

回放缓冲区是离策略强化学习算法的常见构建模块。在策略环境中,每次收集一批数据时,回放缓冲区都会被重新填充,并且其数据会在一定数量的周期内被重复使用。

TorchRL 的回放缓冲区是使用一个通用容器 ReplayBuffer 构建的,该容器接受缓冲区的组件作为参数:存储、写入器、采样器以及可能的某些转换。只有存储(指示回放缓冲区的容量)是必需的。我们还指定了一个无重复的采样器,以避免在一个周期内多次采样同一项。对于 PPO 来说,使用回放缓冲区并不是强制性的,我们可以简单地从收集的批次中采样子批次,但使用这些类可以让我们以一种可重复的方式轻松构建内部训练循环。

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

损失函数

为了方便起见,可以直接从 TorchRL 中导入 PPO 损失,使用 ClipPPOLoss 类。这是使用 PPO 的最简单方式:它隐藏了 PPO 的数学运算及其相关的控制流程。

PPO 需要计算一些“优势估计”。简而言之,优势是一个反映期望回报值的数值,同时处理偏差/方差的权衡。要计算优势,只需 (1) 构建优势模块,该模块利用我们的价值算子,并且 (2) 在每次迭代之前将每批数据传递给该模块。GAE 模块将使用新的 "advantage""value_target" 条目更新输入的 tensordict"value_target" 是一个无梯度的张量,表示价值网络在输入观测下应表示的实证价值。这两者都将被 ClipPPOLoss 用于返回策略和价值损失。

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

训练循环

我们现在已经具备了编写训练循环所需的所有部分。步骤包括:

  • 收集数据

    • 计算优势值

    • 遍历收集的数据以计算损失值

    • 反向传播

    • 优化

    • 重复执行

    • 重复

  • 重复

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
  0%|          | 0/50000 [00:00<?, ?it/s]
  2%|2         | 1000/50000 [00:05<04:21, 187.45it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.0897 (init= 9.0897), step count (max): 10, lr policy:  0.0003:   2%|2         | 1000/50000 [00:05<04:21, 187.45it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.0897 (init= 9.0897), step count (max): 10, lr policy:  0.0003:   4%|4         | 2000/50000 [00:10<04:06, 195.08it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.1393 (init= 9.0897), step count (max): 15, lr policy:  0.0003:   4%|4         | 2000/50000 [00:10<04:06, 195.08it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.1393 (init= 9.0897), step count (max): 15, lr policy:  0.0003:   6%|6         | 3000/50000 [00:15<04:05, 191.11it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.1622 (init= 9.0897), step count (max): 15, lr policy:  0.0003:   6%|6         | 3000/50000 [00:15<04:05, 191.11it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.1622 (init= 9.0897), step count (max): 15, lr policy:  0.0003:   8%|8         | 4000/50000 [00:20<03:54, 196.42it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2023 (init= 9.0897), step count (max): 26, lr policy:  0.0003:   8%|8         | 4000/50000 [00:20<03:54, 196.42it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2023 (init= 9.0897), step count (max): 26, lr policy:  0.0003:  10%|#         | 5000/50000 [00:25<03:45, 199.99it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2179 (init= 9.0897), step count (max): 30, lr policy:  0.0003:  10%|#         | 5000/50000 [00:25<03:45, 199.99it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2179 (init= 9.0897), step count (max): 30, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:30<03:37, 202.22it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2267 (init= 9.0897), step count (max): 28, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:30<03:37, 202.22it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2267 (init= 9.0897), step count (max): 28, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:35<03:30, 204.06it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2419 (init= 9.0897), step count (max): 48, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:35<03:30, 204.06it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2419 (init= 9.0897), step count (max): 48, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:40<03:27, 202.16it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2473 (init= 9.0897), step count (max): 32, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:40<03:27, 202.16it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2473 (init= 9.0897), step count (max): 32, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:44<03:21, 203.78it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2294 (init= 9.0897), step count (max): 38, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:44<03:21, 203.78it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2294 (init= 9.0897), step count (max): 38, lr policy:  0.0003:  20%|##        | 10000/50000 [00:49<03:15, 205.08it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2370 (init= 9.0897), step count (max): 34, lr policy:  0.0003:  20%|##        | 10000/50000 [00:49<03:15, 205.08it/s]
eval cumulative reward:  119.4005 (init:  119.4005), eval step-count: 12, average reward= 9.2370 (init= 9.0897), step count (max): 34, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:54<03:09, 205.97it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2572 (init= 9.0897), step count (max): 45, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:54<03:09, 205.97it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2572 (init= 9.0897), step count (max): 45, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:59<03:04, 205.78it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2417 (init= 9.0897), step count (max): 39, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:59<03:04, 205.78it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2417 (init= 9.0897), step count (max): 39, lr policy:  0.0003:  26%|##6       | 13000/50000 [01:04<02:59, 206.60it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2526 (init= 9.0897), step count (max): 46, lr policy:  0.0003:  26%|##6       | 13000/50000 [01:04<02:59, 206.60it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2526 (init= 9.0897), step count (max): 46, lr policy:  0.0003:  28%|##8       | 14000/50000 [01:08<02:53, 207.11it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2410 (init= 9.0897), step count (max): 63, lr policy:  0.0003:  28%|##8       | 14000/50000 [01:08<02:53, 207.11it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2410 (init= 9.0897), step count (max): 63, lr policy:  0.0003:  30%|###       | 15000/50000 [01:13<02:48, 207.31it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2686 (init= 9.0897), step count (max): 59, lr policy:  0.0002:  30%|###       | 15000/50000 [01:13<02:48, 207.31it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2686 (init= 9.0897), step count (max): 59, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:18<02:43, 208.54it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2717 (init= 9.0897), step count (max): 69, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:18<02:43, 208.54it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2717 (init= 9.0897), step count (max): 69, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:23<02:37, 209.42it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2714 (init= 9.0897), step count (max): 74, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:23<02:37, 209.42it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2714 (init= 9.0897), step count (max): 74, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:28<02:35, 206.24it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2675 (init= 9.0897), step count (max): 64, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:28<02:35, 206.24it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2675 (init= 9.0897), step count (max): 64, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:33<02:29, 207.34it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2829 (init= 9.0897), step count (max): 72, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:33<02:29, 207.34it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2829 (init= 9.0897), step count (max): 72, lr policy:  0.0002:  40%|####      | 20000/50000 [01:37<02:24, 208.03it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2735 (init= 9.0897), step count (max): 74, lr policy:  0.0002:  40%|####      | 20000/50000 [01:37<02:24, 208.03it/s]
eval cumulative reward:  213.3145 (init:  119.4005), eval step-count: 22, average reward= 9.2735 (init= 9.0897), step count (max): 74, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:42<02:18, 209.07it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2928 (init= 9.0897), step count (max): 72, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:42<02:18, 209.07it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2928 (init= 9.0897), step count (max): 72, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:47<02:14, 207.95it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2976 (init= 9.0897), step count (max): 101, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:47<02:14, 207.95it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2976 (init= 9.0897), step count (max): 101, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:52<02:09, 208.92it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2875 (init= 9.0897), step count (max): 77, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:52<02:09, 208.92it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2875 (init= 9.0897), step count (max): 77, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:56<02:03, 209.70it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2983 (init= 9.0897), step count (max): 91, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:56<02:03, 209.70it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.2983 (init= 9.0897), step count (max): 91, lr policy:  0.0002:  50%|#####     | 25000/50000 [02:01<01:58, 210.28it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3110 (init= 9.0897), step count (max): 143, lr policy:  0.0002:  50%|#####     | 25000/50000 [02:01<01:58, 210.28it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3110 (init= 9.0897), step count (max): 143, lr policy:  0.0002:  52%|#####2    | 26000/50000 [02:06<01:54, 210.22it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3021 (init= 9.0897), step count (max): 94, lr policy:  0.0001:  52%|#####2    | 26000/50000 [02:06<01:54, 210.22it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3021 (init= 9.0897), step count (max): 94, lr policy:  0.0001:  54%|#####4    | 27000/50000 [02:11<01:49, 210.24it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3066 (init= 9.0897), step count (max): 92, lr policy:  0.0001:  54%|#####4    | 27000/50000 [02:11<01:49, 210.24it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3066 (init= 9.0897), step count (max): 92, lr policy:  0.0001:  56%|#####6    | 28000/50000 [02:16<01:46, 207.46it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3075 (init= 9.0897), step count (max): 131, lr policy:  0.0001:  56%|#####6    | 28000/50000 [02:16<01:46, 207.46it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3075 (init= 9.0897), step count (max): 131, lr policy:  0.0001:  58%|#####8    | 29000/50000 [02:20<01:40, 208.34it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3161 (init= 9.0897), step count (max): 133, lr policy:  0.0001:  58%|#####8    | 29000/50000 [02:20<01:40, 208.34it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3161 (init= 9.0897), step count (max): 133, lr policy:  0.0001:  60%|######    | 30000/50000 [02:25<01:35, 208.41it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3220 (init= 9.0897), step count (max): 177, lr policy:  0.0001:  60%|######    | 30000/50000 [02:25<01:35, 208.41it/s]
eval cumulative reward:  417.8861 (init:  119.4005), eval step-count: 44, average reward= 9.3220 (init= 9.0897), step count (max): 177, lr policy:  0.0001:  62%|######2   | 31000/50000 [02:35<01:57, 161.77it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3202 (init= 9.0897), step count (max): 125, lr policy:  0.0001:  62%|######2   | 31000/50000 [02:35<01:57, 161.77it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3202 (init= 9.0897), step count (max): 125, lr policy:  0.0001:  64%|######4   | 32000/50000 [02:40<01:45, 170.55it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3177 (init= 9.0897), step count (max): 125, lr policy:  0.0001:  64%|######4   | 32000/50000 [02:40<01:45, 170.55it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3177 (init= 9.0897), step count (max): 125, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:44<01:33, 181.30it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3245 (init= 9.0897), step count (max): 167, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:44<01:33, 181.30it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3245 (init= 9.0897), step count (max): 167, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:49<01:24, 189.45it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3194 (init= 9.0897), step count (max): 163, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:49<01:24, 189.45it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3194 (init= 9.0897), step count (max): 163, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:54<01:16, 195.82it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3277 (init= 9.0897), step count (max): 201, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:54<01:16, 195.82it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3277 (init= 9.0897), step count (max): 201, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:59<01:09, 200.36it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3240 (init= 9.0897), step count (max): 151, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:59<01:09, 200.36it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3240 (init= 9.0897), step count (max): 151, lr policy:  0.0001:  74%|#######4  | 37000/50000 [03:03<01:04, 201.14it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3332 (init= 9.0897), step count (max): 268, lr policy:  0.0001:  74%|#######4  | 37000/50000 [03:03<01:04, 201.14it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3332 (init= 9.0897), step count (max): 268, lr policy:  0.0001:  76%|#######6  | 38000/50000 [03:08<00:58, 204.32it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3328 (init= 9.0897), step count (max): 225, lr policy:  0.0000:  76%|#######6  | 38000/50000 [03:08<00:58, 204.32it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3328 (init= 9.0897), step count (max): 225, lr policy:  0.0000:  78%|#######8  | 39000/50000 [03:13<00:53, 206.15it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3379 (init= 9.0897), step count (max): 328, lr policy:  0.0000:  78%|#######8  | 39000/50000 [03:13<00:53, 206.15it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3379 (init= 9.0897), step count (max): 328, lr policy:  0.0000:  80%|########  | 40000/50000 [03:18<00:48, 207.52it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3348 (init= 9.0897), step count (max): 384, lr policy:  0.0000:  80%|########  | 40000/50000 [03:18<00:48, 207.52it/s]
eval cumulative reward:  1269.8306 (init:  119.4005), eval step-count: 135, average reward= 9.3348 (init= 9.0897), step count (max): 384, lr policy:  0.0000:  82%|########2 | 41000/50000 [03:22<00:43, 208.51it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3394 (init= 9.0897), step count (max): 254, lr policy:  0.0000:  82%|########2 | 41000/50000 [03:23<00:43, 208.51it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3394 (init= 9.0897), step count (max): 254, lr policy:  0.0000:  84%|########4 | 42000/50000 [03:28<00:39, 202.19it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3347 (init= 9.0897), step count (max): 256, lr policy:  0.0000:  84%|########4 | 42000/50000 [03:28<00:39, 202.19it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3347 (init= 9.0897), step count (max): 256, lr policy:  0.0000:  86%|########6 | 43000/50000 [03:32<00:34, 204.61it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3285 (init= 9.0897), step count (max): 176, lr policy:  0.0000:  86%|########6 | 43000/50000 [03:32<00:34, 204.61it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3285 (init= 9.0897), step count (max): 176, lr policy:  0.0000:  88%|########8 | 44000/50000 [03:37<00:29, 205.88it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3364 (init= 9.0897), step count (max): 280, lr policy:  0.0000:  88%|########8 | 44000/50000 [03:37<00:29, 205.88it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3364 (init= 9.0897), step count (max): 280, lr policy:  0.0000:  90%|######### | 45000/50000 [03:42<00:24, 207.64it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3445 (init= 9.0897), step count (max): 321, lr policy:  0.0000:  90%|######### | 45000/50000 [03:42<00:24, 207.64it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3445 (init= 9.0897), step count (max): 321, lr policy:  0.0000:  92%|#########2| 46000/50000 [03:47<00:19, 205.67it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3335 (init= 9.0897), step count (max): 289, lr policy:  0.0000:  92%|#########2| 46000/50000 [03:47<00:19, 205.67it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3335 (init= 9.0897), step count (max): 289, lr policy:  0.0000:  94%|#########3| 47000/50000 [03:52<00:14, 205.79it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3373 (init= 9.0897), step count (max): 313, lr policy:  0.0000:  94%|#########3| 47000/50000 [03:52<00:14, 205.79it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3373 (init= 9.0897), step count (max): 313, lr policy:  0.0000:  96%|#########6| 48000/50000 [03:57<00:09, 207.24it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3350 (init= 9.0897), step count (max): 278, lr policy:  0.0000:  96%|#########6| 48000/50000 [03:57<00:09, 207.24it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3350 (init= 9.0897), step count (max): 278, lr policy:  0.0000:  98%|#########8| 49000/50000 [04:01<00:04, 208.39it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3329 (init= 9.0897), step count (max): 404, lr policy:  0.0000:  98%|#########8| 49000/50000 [04:01<00:04, 208.39it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3329 (init= 9.0897), step count (max): 404, lr policy:  0.0000: 100%|##########| 50000/50000 [04:06<00:00, 209.08it/s]
eval cumulative reward:  1737.9119 (init:  119.4005), eval step-count: 185, average reward= 9.3431 (init= 9.0897), step count (max): 495, lr policy:  0.0000: 100%|##########| 50000/50000 [04:06<00:00, 209.08it/s]

结果

在达到 100 万步的限制之前,算法应已达到 1000 步的最大步数,这是轨迹被截断前的最大步数。

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()

training rewards (average), Max step count (training), Return (test), Max step count (test)

结论与下一步

在本教程中,我们已经学习了:

  1. 如何使用 torchrl 创建和自定义环境;

  2. 如何编写模型和损失函数;

  3. 如何设置一个典型的训练循环。

如果您想进一步尝试本教程,可以应用以下修改:

  • 从效率角度考虑,我们可以并行运行多个模拟以加快数据收集速度。更多信息请查看 ParallelEnv

  • 从日志记录角度考虑,可以在请求渲染后向环境中添加 torchrl.record.VideoRecorder 变换,以获取倒立摆动作的视觉渲染。了解更多请查看 torchrl.record

本页目录