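I've recently been working through the introductory TorchScript tutorial. The official article it links to left me completely lost; I couldn't tell what it was getting at.
Then I found an article by Rene Wang, and a lot of it finally made sense.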
The official tutorial uses this example to show the weakness of a traced module (TracedModule):
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent control flow: which branch runs depends on the input values
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

# Example inputs: 3 rows of 4 features each
x, h = torch.rand(3, 4), torch.rand(3, 4)
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
Output:
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(input, ), )
  _1 = torch.tanh(torch.add(_0, h, alpha=1))
  return (_1, _1)
The tutorial then introduces ScriptModule:
scripted_gate = torch.jit.script(MyDecisionGate())
my_cell = MyCell(scripted_gate)
# The cell is compiled with torch.jit.script, not traced
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.code)
Output:
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)
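At this point the tutorial triumphantly says "hooray", but I was still baffled: I couldn't see any real difference between the ScriptModule's code and the TracedModule's code.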
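Apart from cosmetic names (`input`/`_1` versus `x`/`new_h`), the two listings above are the same, because both only call `(self.dg).forward(...)`; the control flow lives one level down, in the `dg` submodule. A minimal sketch that makes this explicit, assuming the same `MyDecisionGate`/`MyCell` definitions and random example inputs, checks whether each cell's `dg` code contains an `if` at all:

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


x, h = torch.rand(3, 4), torch.rand(3, 4)

# Tracing records only the branch taken for this particular input
# (a TracerWarning is emitted for the x.sum() > 0 comparison).
traced_cell = torch.jit.trace(MyCell(MyDecisionGate()), (x, h))
# Scripting compiles the Python source, so both branches are kept.
scripted_cell = torch.jit.script(MyCell(MyDecisionGate()))

# The top-level code of both cells just calls self.dg.forward(...);
# the difference only shows up in the dg submodule's code.
traced_has_if = "if" in traced_cell.dg.code
scripted_has_if = "if" in scripted_cell.dg.code
print("traced dg contains if:", traced_has_if)
print("scripted dg contains if:", scripted_has_if)
```

Which branch the traced `dg` kept depends on the random weights of `linear`, but either way it is a single straight-line `return`, with no `if`.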
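Rene Wang's article explains it perfectly: the key is to look at the code of the dg submodule (my_cell.dg.code). Here is what each version actually produces: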
traced_gate = torch.jit.trace(my_cell.dg, (x,))
print(traced_gate.code)
-- Output --
c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  after removing the cwd from sys.path.
def forward(self,
    x: Tensor) -> Tensor:
  return x
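The warning says it all: during tracing, `x.sum() > 0` was evaluated once in Python and frozen as a constant, so the traced gate is hard-wired to whichever branch the example input happened to take. A minimal sketch of the consequence, assuming a guaranteed-positive example input `torch.ones(2, 4)` (chosen here so the frozen branch is predictable), feeds a negative input to the traced gate, the scripted gate, and the original eager module:

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = MyDecisionGate()

# Trace with an input whose sum is positive, so only "return x" is recorded.
# torch.jit.trace emits the same TracerWarning as in the output above.
traced_gate = torch.jit.trace(gate, (torch.ones(2, 4),))
scripted_gate = torch.jit.script(MyDecisionGate())

neg = -torch.ones(2, 4)  # sum < 0, so the else branch should run

eager_out = gate(neg)              # -neg, i.e. all ones
traced_out = traced_gate(neg)      # neg unchanged: the branch was frozen at trace time
scripted_out = scripted_gate(neg)  # evaluates the if at run time, matches eager

print(torch.equal(traced_out, eager_out))    # False
print(torch.equal(scripted_out, eager_out))  # True
```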
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell)
print(scripted_cell.code)
# Only dg.code shows that the if/else control flow was captured
print(scripted_cell.dg.code)
-- Output --
def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

RecursiveScriptModule(
  original_name=MyCell
  (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1
This makes it clear that the ScriptModule preserved the if/else control flow, whereas the traced module kept only one branch.
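Because the if/else is part of the compiled program, it also survives serialization, which is what matters once the model is loaded outside the original Python code. A minimal sketch, assuming an in-memory `io.BytesIO` buffer instead of a file on disk, saves the scripted gate with `torch.jit.save`, reloads it with `torch.jit.load`, and checks both branches:

```python
import io

import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


scripted_gate = torch.jit.script(MyDecisionGate())

# Serialize to an in-memory buffer (a file path would work the same way).
buffer = io.BytesIO()
torch.jit.save(scripted_gate, buffer)
buffer.seek(0)
loaded_gate = torch.jit.load(buffer)

pos = torch.ones(2, 4)
neg = -torch.ones(2, 4)

# The reloaded module still contains the if/else, so both branches work.
print(torch.equal(loaded_gate(pos), pos))   # True: "if" branch returns x
print(torch.equal(loaded_gate(neg), -neg))  # True: "else" branch returns -x
print(loaded_gate.code)
```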
Everything above was run on torch 1.4.0; the official tutorial's examples may have been written for an older version.