The difference between TracedModule and ScriptModule in TorchScript

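I have recently been reading the introduction to TorchScript. I went through the article linked from the official site and came away lost in the fog, with no idea what it was getting at.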

Then I found Rene Wang's article, and things became much clearer.

The official example of TracedModule's weakness looks like this:

import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

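# Example inputs for tracing (random 3x4 tensors, matching Linear(4, 4))
x, h = torch.rand(3, 4), torch.rand(3, 4)

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))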
print(traced_cell.code)

The output is:

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(input, ), )
  _1 = torch.tanh(torch.add(_0, h, alpha=1))
  return (_1, _1)

The tutorial then introduces ScriptModule:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

And the output is:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

At this point the tutorial cheers "hooray", but I was still completely confused: I could not see any difference between the ScriptModule code and the TracedModule code.
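One way to convince yourself that the two really do differ, even though the top-level printouts look alike, is to run both versions of the gate on an input whose sum is negative. The snippet below is a minimal sketch and is not taken from the tutorial; the tensors `pos` and `neg` are illustrative names introduced here.

```python
import warnings
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

# Illustrative inputs (assumed for this sketch): one with a positive sum, one with a negative sum
pos = torch.ones(3, 4)
neg = -torch.ones(3, 4)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    # Tracing with `pos` records only the branch taken for that input (`return x`)
    traced_gate = torch.jit.trace(MyDecisionGate(), (pos,))

# Scripting compiles the Python source, so the if/else is kept
scripted_gate = torch.jit.script(MyDecisionGate())

traced_out = traced_gate(neg)
scripted_out = scripted_gate(neg)

print(torch.equal(traced_out, neg))    # True: the trace replays `return x`
print(torch.equal(scripted_out, pos))  # True: the script takes the else branch and returns -x
```

The traced gate returns the negative input unchanged, while the scripted gate negates it, which is what the original Python module would do.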

Rene Wang's article explains it well: the key is to look at my_cell.dg.code. This is what the two versions actually look like:

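# Trace a fresh, unscripted gate: my_cell.dg was replaced by a scripted gate earlier,
# and tracing a scripted module would keep its control flow
traced_gate = torch.jit.trace(MyDecisionGate(), (x,))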
print(traced_gate.code)

--Output--
c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  after removing the cwd from sys.path.

def forward(self,
    x: Tensor) -> Tensor:
  return x

scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
print(traced_cell.code)
# The if/else control flow is only visible in dg.code
print(traced_cell.dg.code)

--Output--

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

RecursiveScriptModule(
  original_name=MyCell
  (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

This makes it clear that the ScriptModule captured the if/else control flow.
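The same check can be done programmatically instead of by reading the printed code. The sketch below assumes that a scripted module's `graph` keeps submodule calls as method calls while `inlined_graph` expands them, and it uses a hypothetical helper `has_if` (not part of PyTorch) that simply searches the graph's text for a `prim::If` node.

```python
import torch

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

# Scripting the cell also scripts its submodules recursively
scripted_cell = torch.jit.script(MyCell(MyDecisionGate()))

def has_if(graph):
    """Hypothetical helper: report whether a TorchScript graph contains an if node."""
    return "prim::If" in str(graph)

# The top-level graph only calls dg.forward, so the branch is expected to be hidden here
print(has_if(scripted_cell.graph))
# The submodule's own graph holds the branch
print(has_if(scripted_cell.dg.graph))
# The inlined graph expands submodule calls, so the branch shows up at the top level
print(has_if(scripted_cell.inlined_graph))
```

This mirrors the observation above: the control flow lives in the submodule, so the cell's own `code` does not show it.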

This was run on torch 1.4.0; the examples in the official tutorial may be based on an older version.
