TorchScript中的动态并行性¶
Created On: Jul 28, 2020 | Last Updated: Dec 02, 2024 | Last Verified: Nov 05, 2024
警告
TorchScript 不再处于活跃开发状态。
在本教程中,我们介绍了在TorchScript中进行*动态互操作并行性*的语法。这种并行性具有以下特点:
动态 - 创建的并行任务数量及其工作量可以依赖于程序的控制流。
互操作 - 这种并行性关注于并行运行TorchScript程序片段。这与*内操作并行性*不同,后者关注于将单个算子分解,并并行运行算子的子集。
基本语法¶
动态并行性的两个重要API是:
torch.jit.fork(fn : Callable[..., T], *args, **kwargs) -> torch.jit.Future[T]
torch.jit.wait(fut : torch.jit.Future[T]) -> T
通过示例可以很好地演示这些API的工作方式:
import torch
def foo(x):
return torch.neg(x)
@torch.jit.script
def example(x):
# Call `foo` using parallelism:
# First, we "fork" off a task. This task will run `foo` with argument `x`
future = torch.jit.fork(foo, x)
# Call `foo` normally
x_normal = foo(x)
# Second, we "wait" on the task. Since the task may be running in
# parallel, we have to "wait" for its result to become available.
# Notice that by having lines of code between the "fork()" and "wait()"
# call for a given Future, we can overlap computations so that they
# run in parallel.
x_parallel = torch.jit.wait(future)
return x_normal, x_parallel
print(example(torch.ones(1))) # (-1., -1.)
fork()``接受可调用对象``fn``及其参数``args``和``kwargs
,并创建一个异步任务来执行``fn``。fn``可以是函数、方法或模块实例。``fork()``返回一个对执行结果值的引用,称为``Future
。由于``fork``在创建异步任务后立即返回,``fn``可能还未在执行完毕时就执行了``fork()``调用后的代码行。因此,需要使用``wait()``等待异步任务完成并返回值。
这些结构可以用于重叠函数中的语句执行(见实际示例章节)或与循环等其他语言结构组合:
import torch
from typing import List
def foo(x):
return torch.neg(x)
@torch.jit.script
def example(x):
futures : List[torch.jit.Future[torch.Tensor]] = []
for _ in range(100):
futures.append(torch.jit.fork(foo, x))
results = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.sum(torch.stack(results))
print(example(torch.ones([])))
备注
当我们初始化一个空的Futures列表时,需要对``futures``添加显式的类型注解。在TorchScript中,空容器默认假定它们包含Tensor值,因此我们在列表构造函数中注解类型为``List[torch.jit.Future[torch.Tensor]]``
此示例使用``fork()``启动函数``foo``的100个实例,等待100个任务完成,然后求和结果,返回``-100.0``。
应用示例:双向LSTM的集成¶
让我们尝试将并行性应用于一个更实际的示例,并看看我们能从中获得什么样的性能。首先,让我们定义基准模型:一个双向LSTM层的集成。
import torch, time
# In RNN parlance, the dimensions we care about are:
# # of time-steps (T)
# Batch size (B)
# Hidden size/number of "channels" (C)
T, B, C = 50, 50, 1024
# A module that defines a single "bidirectional LSTM". This is simply two
# LSTMs applied to the same sequence, but one in reverse
class BidirectionalRecurrentLSTM(torch.nn.Module):
def __init__(self):
super().__init__()
self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C)
self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C)
def forward(self, x : torch.Tensor) -> torch.Tensor:
# Forward layer
output_f, _ = self.cell_f(x)
# Backward layer. Flip input in the time dimension (dim 0), apply the
# layer, then flip the outputs in the time dimension
x_rev = torch.flip(x, dims=[0])
output_b, _ = self.cell_b(torch.flip(x, dims=[0]))
output_b_rev = torch.flip(output_b, dims=[0])
return torch.cat((output_f, output_b_rev), dim=2)
# An "ensemble" of `BidirectionalRecurrentLSTM` modules. The modules in the
# ensemble are run one-by-one on the same input then their results are
# stacked and summed together, returning the combined result.
class LSTMEnsemble(torch.nn.Module):
def __init__(self, n_models):
super().__init__()
self.n_models = n_models
self.models = torch.nn.ModuleList([
BidirectionalRecurrentLSTM() for _ in range(self.n_models)])
def forward(self, x : torch.Tensor) -> torch.Tensor:
results = []
for model in self.models:
results.append(model(x))
return torch.stack(results).sum(dim=0)
# For a head-to-head comparison to what we're going to do with fork/wait, let's
# instantiate the model and compile it with TorchScript
ens = torch.jit.script(LSTMEnsemble(n_models=4))
# Normally you would pull this input out of an embedding table, but for the
# purpose of this demo let's just use random data.
x = torch.rand(T, B, C)
# Let's run the model once to warm up things like the memory allocator
ens(x)
x = torch.rand(T, B, C)
# Let's see how fast it runs!
s = time.time()
ens(x)
print('Inference took', time.time() - s, ' seconds')
在我的机器上,这种网络运行需要``2.05``秒。我们可以做得更好!
并行化前向和后向层¶
一个非常简单的操作是将``BidirectionalRecurrentLSTM``中的前向和后向层进行并行化。对于此操作,计算结构是静态的,因此实际上我们甚至不需要任何循环。让我们像这样重写``BidirectionalRecurrentLSTM``的``forward``方法:
def forward(self, x : torch.Tensor) -> torch.Tensor:
# Forward layer - fork() so this can run in parallel to the backward
# layer
future_f = torch.jit.fork(self.cell_f, x)
# Backward layer. Flip input in the time dimension (dim 0), apply the
# layer, then flip the outputs in the time dimension
x_rev = torch.flip(x, dims=[0])
output_b, _ = self.cell_b(torch.flip(x, dims=[0]))
output_b_rev = torch.flip(output_b, dims=[0])
# Retrieve the output from the forward layer. Note this needs to happen
# *after* the stuff we want to parallelize with
output_f, _ = torch.jit.wait(future_f)
return torch.cat((output_f, output_b_rev), dim=2)
在此示例中,forward()``将``cell_f``的执行委托给另一个线程,同时继续执行``cell_b
。这使得两个单元的执行相互重叠。
重新运行修改后的脚本后,运行时间降低到``1.71``秒,提升了``17%``!
附注:并行性可视化¶
我们尚未完成对模型的优化,但值得介绍我们用于可视化性能的工具。其中一个重要工具是 PyTorch profiler。
让我们用分析器以及 Chrome 的追踪导出功能来可视化我们的并行化模型的性能:
with torch.autograd.profiler.profile() as prof:
ens(x)
prof.export_chrome_trace('parallel.json')
这段代码将写出一个名为 parallel.json
的文件。如果你在 Google Chrome 中导航到 chrome://tracing
,点击 Load
按钮并加载该 JSON 文件,你应该能看到如下的时间线:

时间线的水平轴表示时间,垂直轴表示执行线程。正如我们所见,我们一次运行两个 lstm
实例。这是我们辛苦工作的结果并行化双向层!
对模型集中的模型进行并行化¶
你可能注意到我们的代码中还有另一个并行化的机会:我们可以同时并行运行 LSTMEnsemble
中的模型。实现方法很简单,这就是我们应该如何修改 LSTMEnsemble
的 forward
方法:
def forward(self, x : torch.Tensor) -> torch.Tensor:
# Launch tasks for each model
futures : List[torch.jit.Future[torch.Tensor]] = []
for model in self.models:
futures.append(torch.jit.fork(model, x))
# Collect the results from the launched tasks
results : List[torch.Tensor] = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.stack(results).sum(dim=0)
或者,如果你喜欢简洁,我们可以使用列表推导式:
def forward(self, x : torch.Tensor) -> torch.Tensor:
futures = [torch.jit.fork(model, x) for model in self.models]
results = [torch.jit.wait(fut) for fut in futures]
return torch.stack(results).sum(dim=0)
正如介绍中所述,我们使用循环来为模型集中每个模型创建任务。然后我们使用另一个循环等待所有任务完成。这提供了更多的计算重叠。
通过这个小更新,脚本运行时间缩短至 1.4
秒,总提速达到 32%
!仅两行代码就能取得不错的效果。
我们还可以再次使用 Chrome 追踪器来查看具体发生了什么:

现在我们可以看到所有 LSTM
实例完全处于并行运行状态。
总结¶
在本教程中,我们了解了``fork()`` 和 wait()
,即在 TorchScript 中进行动态、操作并行性的基本 API。我们看到了使用这些函数并行化函数、方法或``Module``执行的几种典型使用模式。最后,我们通过一个示例展示了如何使用这种技术优化模型,并探索了 PyTorch 提供的性能测量和可视化工具。