The Difference Between TracedModule and ScriptModule in TorchScript

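I've recently been reading the introductory material on TorchScript. The article linked from the official docs left me completely lost. I couldn't tell what it was trying to show.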

Then I found an article by Rene Wang, and a lot of it finally made sense.

The official tutorial illustrates the shortcoming of TracedModule with this example:

import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

# Example inputs: 3 vectors of size 4, matching Linear(4, 4)
x, h = torch.rand(3, 4), torch.rand(3, 4)

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)

Output:

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(input, ), )
  _1 = torch.tanh(torch.add(_0, h, alpha=1))
  return (_1, _1)
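The printed code gives no hint that anything was lost. However, torch.jit.trace does emit a torch.jit.TracerWarning when it hits the data-dependent `if` in MyDecisionGate. Below is a minimal sketch that catches that warning programmatically instead of reading it off the console. It assumes only the standard `warnings` module and reuses the tutorial's classes. The names `caught` and `tracer_warnings` are illustrative.

```python
import warnings

import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


x, h = torch.rand(3, 4), torch.rand(3, 4)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeated ones
    traced_cell = torch.jit.trace(MyCell(MyDecisionGate()), (x, h))

# Keep only the warnings the tracer raised (e.g. tensor -> Python bool conversion)
tracer_warnings = [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
print(len(tracer_warnings) > 0)
for w in tracer_warnings:
    print(w.message)
```

In a test suite, checking that this list is empty is one possible way to flag models that should probably be scripted rather than traced.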

Next, the official tutorial introduces ScriptModule:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.code)

Output:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

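At this point the tutorial triumphantly says "hooray", but I was still completely baffled. Apart from variable names (input vs. x, _1 vs. new_h), the ScriptModule's code looks exactly like the TracedModule's code.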

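Rene Wang's article explains it well. The top-level code only shows the call (self.dg).forward(...), so the difference is hidden inside the dg submodule. The key is to look at its code (.dg.code). Here is what the two versions actually look like: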

# Trace a fresh, plain gate: tracing an already-scripted module is a no-op
traced_gate = torch.jit.trace(MyDecisionGate(), (x,))
print(traced_gate.code)

--Output--
c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  after removing the cwd from sys.path.

def forward(self,
    x: Tensor) -> Tensor:
  return x

scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell)
print(scripted_cell.code)
# Only dg.code reveals that the if/else control flow was preserved
print(scripted_cell.dg.code)

--Output--

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

RecursiveScriptModule(
  original_name=MyCell
  (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1
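The top-level .code only shows calls such as (self.dg).forward(...). A convenient way to compare the two modules is therefore to print the code of every submodule. Below is a minimal sketch, assuming a hypothetical helper named `submodule_code`. It relies on nn.Module's named_modules() and on the .code property that traced and scripted modules expose, and it redefines the tutorial's classes so that it runs on its own.

```python
import warnings

import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


def submodule_code(module):
    """Hypothetical helper: map each submodule name to its TorchScript source."""
    return {name or "<root>": m.code for name, m in module.named_modules()}


x, h = torch.rand(3, 4), torch.rand(3, 4)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the expected TracerWarning
    traced_cell = torch.jit.trace(MyCell(MyDecisionGate()), (x, h))
scripted_cell = torch.jit.script(MyCell(MyDecisionGate()))

for label, cell in (("traced", traced_cell), ("scripted", scripted_cell)):
    for name, code in submodule_code(cell).items():
        print(f"--- {label}: {name} ---")
        print(code)
```

In the traced cell, the dg entry contains a single return statement. In the scripted cell, it contains the if/else shown in the output above.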

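This clearly shows that the ScriptModule preserved the if/else control flow. The traced gate kept only return x, which is the branch taken for the example input.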
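The difference also shows up at runtime. The minimal sketch below traces and scripts the same gate, then runs both on an input that takes the other branch. The tensors `pos` and `neg` are hand-picked for illustration and do not come from either article.

```python
import warnings

import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


pos = torch.ones(2, 2)   # sum > 0: takes the `return x` branch
neg = -torch.ones(2, 2)  # sum < 0: eager mode takes the `return -x` branch

gate = MyDecisionGate()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # the TracerWarning is expected here
    traced_gate = torch.jit.trace(gate, (pos,))  # freezes the `return x` branch
scripted_gate = torch.jit.script(gate)           # compiles both branches

for name, module in (("traced", traced_gate), ("scripted", scripted_gate)):
    agrees = all(torch.equal(module(t), gate(t)) for t in (pos, neg))
    print(name, "matches eager on both inputs:", agrees)
# traced matches eager on both inputs: False
# scripted matches eager on both inputs: True
```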

These results come from torch 1.4.0. The official tutorial's examples may have been written for an older version.
