Image Style Transfer with Deep Convolutional Networks (Part 3): Code Analysis

Understanding the Deep Photo Style Transfer source code

Taylor Guo, May 17, 2017

Where do you download the code? You have to do some of the work yourself……

  • Lua provides the require function for loading libraries:
    1. require searches a list of directories to locate and load the file;
    2. require checks whether a file has already been loaded, so the same file is never loaded twice.

    The require function loads other Lua files, similar to #include in C++ or import in Java. The path that require uses differs somewhat from an ordinary path: rather than a list of directories, it is a list of patterns, where each pattern describes one way of turning a virtual file name (require's modname argument) into a real file name. More precisely, each pattern is a file name containing an optional question mark (?). When matching, Lua first substitutes the virtual file name for the question mark and checks whether such a file exists; if it does not, it tries the next pattern in the same way. (A minimal sketch of this lookup appears after this list.)

    – For example, given the path: ?; ?.lua; c:\windows\?; /usr/local/lua/?/?.lua
    calling require("add") will try to open these files:
    add
    add.lua
    c:\windows\add
    /usr/local/lua/add/add.lua
    Reference: http://blog.csdn.net/xxxxyyyy2012/article/details/41675345

  • torch: Torch is the main package of Torch7. It defines the multi-dimensional Tensor data structure and its mathematical operations, and provides many useful utilities for handling files, serializing objects of arbitrary type, and other tasks. (A short tensor example follows this list.)
    References:
    https://github.com/torch/torch7
    https://github.com/torch/torch7/wiki/Cheatsheet#cuda

  • nn: the neural network package, built from modules. Module is the abstract base class; its subclass is Container, and Container in turn has the three subclasses most important for building networks: Sequential, Parallel, and Concat. A network built from these containers can hold simple layers such as Linear, Mean, Max, and Reshape, as well as convolutional layers and activation functions such as Tanh and ReLU. (A minimal Sequential sketch follows this list.)
    References:
    http://blog.csdn.net/hungryof/article/details/52022415 ; https://github.com/soumith/cvpr2015/blob/master/Deep%20Learning%20with%20Torch.ipynb

    [Figure: the subclasses used to build neural networks in the torch nn package]

  • image: the image-processing package in the Torch7 distribution. (A short example follows this list.)
    It includes:
    saving and loading JPEG, PNG, PPM, and PGM images;
    simple transformations: translation, scaling, and rotation;
    parameterized transformations: convolution and warping;
    simple drawing: text and rectangles;
    graphical interfaces: display and windows;
    color-space conversions: to and from RGB, YUV, Lab, and HSL;
    tensor constructors: the Lenna and Fabio test images, Gaussian and Laplacian kernels.
    Reference: https://github.com/torch/image

  • optim: several optimization methods for Torch7, plus a logger. (A worked L-BFGS example follows this list.)
    Optimization algorithms:
    Stochastic Gradient Descent
    Averaged Stochastic Gradient Descent
    L-BFGS
    Conjugate Gradients
    AdaDelta
    AdaGrad
    Adam
    AdaMax
    FISTA with backtracking line search
    Nesterov’s Accelerated Gradient method
    RMSprop
    Rprop
    CMAES
    Reference: https://github.com/torch/optim

  • loadcaffe: loads Caffe networks in Torch7 without a Caffe dependency; only protobuf needs to be installed. (A loading sketch follows this list.)
    References:
    https://github.com/szagoruyko/loadcaffe
    https://github.com/torch/torch7/wiki/Cheatsheet
    https://github.com/torch/rocks/blob/master/loadcaffe-1.0-0.rockspec

  • libcuda_utils: the CUDA utility library built as part of the deep-photo-styletransfer repository.
    Reference:
    https://github.com/luanfujun/deep-photo-styletransfer

  • cutorch: the CUDA backend for Torch7. (A short example follows this list.)
    It provides:

    • a new tensor type, torch.CudaTensor, which runs on the GPU; cutorch supports most tensor operations on it;
    • other GPU tensor types as well, though with limited functionality;
    • cutorch.* functions for getting and setting the GPU, querying device properties, memory usage, and so on.
      Reference: https://github.com/torch/cutorch
  • cunn: the CUDA backend for the neural network package.
    It provides CUDA implementations of the modules in the nn package. (Covered by the cutorch sketch after this list.)
    Reference:
    https://github.com/torch/cunn

  • matio: a C library for reading and writing MATLAB .mat files. (A short example follows this list.)
    References:
    https://github.com/soumith/matio-ffi.torch
    https://sourceforge.net/projects/matio/
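
Below are a few minimal, self-contained sketches of the packages above, in the order they are listed. First, require's lookup and caching; "mymodule" and the extra search entry are hypothetical, so the snippet assumes a file resolving to that name exists on the path.

-- Append a hypothetical pattern to the search path.
package.path = package.path .. ";./lib/?.lua"

-- First call: Lua substitutes "mymodule" for each '?' in package.path,
-- loads the first file that exists, and caches the result in package.loaded.
local m1 = require("mymodule")

-- Second call: no file search; the cached value is returned.
local m2 = require("mymodule")
assert(m1 == m2)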
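
A short tensor example for the torch package, sketching basic tensor creation and math:

require 'torch'

local t = torch.rand(3, 4)          -- 3x4 tensor of uniform random values
print(t:size())                     -- 3 4
print(t:sum(), t:mean())            -- reductions over all elements
local u = t:clone():mul(2):add(1)   -- elementwise arithmetic on a copy
print(u[1][1], 2 * t[1][1] + 1)     -- the same value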
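
A minimal Sequential sketch for the nn package, composing simple layers as described above:

require 'nn'

local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 20))   -- fully connected layer
mlp:add(nn.ReLU())           -- activation
mlp:add(nn.Linear(20, 2))
print(mlp:forward(torch.rand(10)))   -- forward pass on a random input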
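
A short example for the image package, using its built-in Lenna test image (the output file name is illustrative):

require 'image'

local img = image.lena()                           -- 3 x 512 x 512 tensor
local small = image.scale(img, 256, 256, 'bilinear')
local rotated = image.rotate(small, math.pi / 8)   -- angle in radians
image.save('lena_rotated.png', rotated)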
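
A worked L-BFGS example for the optim package: minimize f(x) = (x - 3)^2. The feval closure returns (loss, gradient), the same contract that neuralstyle_seg.lua's feval follows further below.

require 'torch'
require 'optim'

local function feval(x)
  local loss = (x[1] - 3) ^ 2
  local grad = torch.Tensor{ 2 * (x[1] - 3) }
  return loss, grad
end

local x0 = torch.Tensor{0}
local xstar, losses = optim.lbfgs(feval, x0, {maxIter = 50})
print(xstar[1])   -- approximately 3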
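
A loading sketch for loadcaffe, using the same model paths as the script's default options (the model files must have been downloaded first):

require 'loadcaffe'

local cnn = loadcaffe.load('models/VGG_ILSVRC_19_layers_deploy.prototxt',
                           'models/VGG_ILSVRC_19_layers.caffemodel', 'nn')
print(cnn)   -- prints the layer stack: conv1_1, relu1_1, ..., prob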
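
A short example covering cutorch and cunn: select a GPU and move a tensor and a module onto it (assumes a CUDA-capable machine):

require 'cutorch'
require 'cunn'

cutorch.setDevice(1)                      -- GPU IDs are 1-based
print(cutorch.getDeviceCount())           -- number of visible GPUs
local a = torch.rand(1000):cuda()         -- torch.CudaTensor on the GPU
local layer = nn.Linear(1000, 10):cuda()  -- cunn supplies the GPU kernels
print(layer:forward(a):size())            -- 10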
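
A short sketch for matio ('data.mat' and the variable name 'x' are hypothetical):

require 'torch'
local matio = require 'matio'

local x = matio.load('data.mat', 'x')     -- read one variable as a tensor
matio.save('out.mat', torch.rand(3, 3))   -- write a tensor back out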

Code

  • The main Lua files are deepmatting_seg.lua and neuralstyle_seg.lua.
    - deepmatting_seg.lua contains the complete network construction, the Matting Laplacian regularization term, the numerical optimization, and the loss computation.
    - neuralstyle_seg.lua contains the complete network construction, the semantic segmentation masks, the numerical optimization, and the loss computation.

neuralstyle_seg.lua code


require 'torch'
require 'nn'
require 'image'
require 'optim'

require 'loadcaffe'
require 'libcuda_utils'

require 'cutorch'
require 'cunn'


--Lua variables are global by default: a variable defined in one file can be accessed from every file, unless it is restricted with local.
--torch.CmdLine() parses the arguments a script is run with on the command line; it can also redirect printed output to a log file.
--option(name, default, help) registers an optional argument. Names begin with '-'.
--Reference: http://torch7.readthedocs.io/en/latest/cmdline/index.html



local matio = require 'matio'
local cmd = torch.CmdLine()

-- Basic options
cmd:option('-style_image', 'examples/inputs/seated-nude.jpg', 'Style target image')
cmd:option('-content_image', 'examples/inputs/tubingen.jpg','Content target image')
cmd:option('-style_seg', '', 'Style segmentation')
cmd:option('-style_seg_idxs', '', 'Style seg idxs')
cmd:option('-content_seg', '', 'Content segmentation')
cmd:option('-content_seg_idxs', '', 'Content seg idxs')

cmd:option('-gpu', 0, 'Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1')

-- Optimization options
cmd:option('-content_weight', 5e0)
cmd:option('-style_weight', 1e2)
cmd:option('-tv_weight', 1e-3)
cmd:option('-num_iterations', 1000)

-- Output options
cmd:option('-print_iter', 1)
cmd:option('-save_iter', 100)
cmd:option('-output_image', 'out.png') 
cmd:option('-index', 1)
cmd:option('-serial', 'serial_example') 

-- Other options
cmd:option('-proto_file', 'models/VGG_ILSVRC_19_layers_deploy.prototxt')
cmd:option('-model_file', 'models/VGG_ILSVRC_19_layers.caffemodel')
cmd:option('-backend', 'nn', 'nn|cudnn|clnn')
cmd:option('-cudnn_autotune', false)
cmd:option('-seed', 612)

cmd:option('-content_layers', 'relu4_2', 'layers for content')
cmd:option('-style_layers',   'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style')

local function main(params)
--Set up the GPU. With multiple GPUs you can switch the default device (where CUDA tensors are allocated for computation).
-- GPU IDs are 1-based: with 4 GPUs the valid calls are setDevice(1) through setDevice(4).
  cutorch.setDevice(params.gpu + 1)
  cutorch.setHeapTracking(true)

--When torch initializes, seed() seeds the random number generator.
--It can also be re-initialized with manualSeed().
--manualSeed([gen,] number) seeds the random number generator with the given number.
  torch.manualSeed(params.seed)

--getDevice(): returns the index of the currently selected GPU.
  idx = cutorch.getDevice()
  print('gpu, idx = ', params.gpu, idx)

  -- content: pitie transferred input image
  --image.load(filename, [channels (1 for grayscale, 3 for color), tensor type (float, double, or byte)]); the last two arguments are optional.
  -- preprocess is a user-defined function, described in detail further below.
  -- It rescales values from [0,1] to [0,255], converts RGB to BGR, and subtracts the mean pixel.
  -- local params = cmd:parse(arg) is defined at the end of the file.
  local content_image = image.load(params.content_image, 3)
  local content_image_caffe = preprocess(content_image):float():cuda()
  local content_layers = params.content_layers:split(",")

  -- style: target model image
  local style_image = image.load(params.style_image, 3)
  local style_image_caffe = preprocess(style_image):float():cuda()
  local style_layers = params.style_layers:split(",")

  local c, h, w = content_image:size(1), content_image:size(2), content_image:size(3)
  local _, h2, w2 = style_image:size(1), style_image:size(2), style_image:size(3)
  local index = params.index

  -- segmentation images 
  -- Semantic segmentation: masks are applied to both the content and the style image; the object colors used in the masks can be customized.
  --[
  local content_seg = image.load(params.content_seg, 3)
  content_seg = image.scale(content_seg, w, h, 'bilinear')
  local style_seg = image.load(params.style_seg, 3)
  style_seg = image.scale(style_seg, w2, h2, 'bilinear')
  local color_codes = {'blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple'}
  local color_content_masks, color_style_masks = {}, {}
  for j = 1, #color_codes do
    local content_mask_j = ExtractMask(content_seg, color_codes[j])
    local style_mask_j = ExtractMask(style_seg, color_codes[j])
    table.insert(color_content_masks, content_mask_j)
    table.insert(color_style_masks, style_mask_j)
  end 
  --]]

  -- Set up the network, inserting style and content loss modules
  local content_losses, style_losses = {}, {}
  local next_content_idx, next_style_idx = 1, 1
  local net = nn.Sequential()

  if params.tv_weight > 0 then
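    -- The TV (total variation) smoothness term acts directly on the image
    -- pixels, so it is inserted first, ahead of all the VGG layers.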
    local tv_mod = nn.TVLoss(params.tv_weight):float():cuda()
    net:add(tv_mod)
  end

  -- Load the VGG-19 network; style and content loss modules are inserted while iterating over its layers below.
  local cnn = loadcaffe.load(params.proto_file, params.model_file, params.backend):float():cuda()

  paths.mkdir(tostring(params.serial))
  print('Exp serial:', params.serial)

  for i = 1, #cnn do
    if next_content_idx <= #content_layers or next_style_idx <= #style_layers then
      local layer = cnn:get(i)
      local name = layer.name
      local layer_type = torch.type(layer)
      local is_pooling = (layer_type == 'nn.SpatialMaxPooling' or layer_type == 'cudnn.SpatialMaxPooling')
      local is_conv    = (layer_type == 'nn.SpatialConvolution' or layer_type == 'cudnn.SpatialConvolution')

      net:add(layer)

      if is_pooling then
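        -- Max pooling halves the spatial resolution of the feature maps,
        -- so the segmentation masks are downsampled in step to stay aligned.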
        for k = 1, #color_codes do
          color_content_masks[k] = image.scale(color_content_masks[k], math.ceil(color_content_masks[k]:size(2)/2), math.ceil(color_content_masks[k]:size(1)/2))
          color_style_masks[k]   = image.scale(color_style_masks[k],   math.ceil(color_style_masks[k]:size(2)/2),   math.ceil(color_style_masks[k]:size(1)/2))
        end
      elseif is_conv then
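        -- Convolutions mix neighboring pixels, so the masks are smoothed
        -- with a 3x3 average pool to roughly match the receptive field.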
        local sap = nn.SpatialAveragePooling(3,3,1,1,1,1):float()
        for k = 1, #color_codes do
          color_content_masks[k] = sap:forward(color_content_masks[k]:repeatTensor(1,1,1))[1]:clone()
          color_style_masks[k]   = sap:forward(color_style_masks[k]:repeatTensor(1,1,1))[1]:clone()
        end
      end 
      color_content_masks = deepcopy(color_content_masks)
      color_style_masks = deepcopy(color_style_masks)


      if name == content_layers[next_content_idx] then
        print("Setting up content layer", i, ":", layer.name)
        local target = net:forward(content_image_caffe):clone()
        local loss_module = nn.ContentLoss(params.content_weight, target, false):float():cuda()
        net:add(loss_module)
        table.insert(content_losses, loss_module)
        next_content_idx = next_content_idx + 1
      end

     if name == style_layers[next_style_idx] then
        print("Setting up style layer  ", i, ":", layer.name)
        local gram = GramMatrix():float():cuda()
        local target_features = net:forward(style_image_caffe):clone()

        local target_grams = {}
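        -- For each segmentation color: mask the style features, take the
        -- Gram matrix of the masked features, and normalize by the mask's
        -- mean coverage so sparsely covered colors are not over-weighted.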

        for j = 1, #color_codes do 
          local l_style_mask_ori = color_style_masks[j]:clone():cuda()
          local l_style_mask = l_style_mask_ori:repeatTensor(1,1,1):expandAs(target_features)
          local l_style_mean = l_style_mask_ori:mean()

          local masked_target_features = torch.cmul(l_style_mask, target_features)
          local masked_target_gram = gram:forward(masked_target_features):clone()
          if l_style_mean > 0 then
            masked_target_gram:div(target_features:nElement() * l_style_mean)
          end 
          table.insert(target_grams, masked_target_gram)
        end 

        local loss_module = nn.StyleLossWithSeg(params.style_weight, target_grams, color_content_masks, color_codes, next_style_idx, false):float():cuda()

        net:add(loss_module)
        table.insert(style_losses, loss_module)
        next_style_idx = next_style_idx + 1
      end 

    end
  end

  -- We don't need the base CNN anymore, so clean it up to save memory.
  cnn = nil
  for i=1,#net.modules do
    local module = net.modules[i]
    if torch.type(module) == 'nn.SpatialConvolutionMM' then
        -- remove these, not used, but uses gpu memory
        module.gradWeight = nil
        module.gradBias = nil
    end
  end
  collectgarbage()

  local mean_pixel = torch.CudaTensor({103.939, 116.779, 123.68})
  local meanImage = mean_pixel:view(3, 1, 1):expandAs(content_image_caffe)

  local img = torch.randn(content_image:size()):float():mul(0.0001):cuda()

  -- Run it through the network once to get the proper size for the gradient
  -- All the gradients will come from the extra loss modules, so we just pass
  -- zeros into the top of the net on the backward pass.
  local y = net:forward(img)
  local dy = img.new(#y):zero()

  -- Declaring this here lets us access it in maybe_print
  local optim_state = {
      maxIter = params.num_iterations,
      tolX = 0, tolFun = -1,
      verbose=true, 
  }

  local function maybe_print(t, loss)
    local verbose = (params.print_iter > 0 and t % params.print_iter == 0)
    if verbose then
      print(string.format('Iteration %d / %d', t, params.num_iterations))
      for i, loss_module in ipairs(content_losses) do
        print(string.format('  Content %d loss: %f', i, loss_module.loss))
      end
      for i, loss_module in ipairs(style_losses) do
        print(string.format('  Style %d loss: %f', i, loss_module.loss))
      end
      print(string.format('  Total loss: %f', loss))
    end
  end

  local function maybe_save(t)
    local should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save then
      local disp = deprocess(img:double())
      disp = image.minmax{tensor=disp, min=0, max=1}
      local filename = params.serial .. '/out' .. tostring(index) .. '_t_' .. tostring(t) .. '.png'
      image.save(filename, disp)
    end
  end

  local num_calls = 0
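  -- feval follows the optim contract: it receives the current iterate and
  -- returns (loss, gradient). The iterate here is the image being
  -- synthesized; optim.lbfgs updates img in place, so the closure reads
  -- img directly and ignores its argument.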
  local function feval(AffineModel) 
    num_calls = num_calls + 1

    local output = torch.add(img, meanImage)
    local input  = torch.add(content_image_caffe, meanImage)

    net:forward(img)

    local gradient_VggNetwork = net:updateGradInput(img, dy)

    local grad = gradient_VggNetwork

    local loss = 0
    for _, mod in ipairs(content_losses) do
      loss = loss + mod.loss
    end
    for _, mod in ipairs(style_losses) do
      loss = loss + mod.loss
    end
    maybe_print(num_calls, loss)
    maybe_save(num_calls)

    collectgarbage()

    -- optim.lbfgs expects a vector for gradients
    return loss, grad:view(grad:nElement()) 
  end

  -- Run optimization.
  local x, losses = optim.lbfgs(feval, img, optim_state)  
end


function build_filename(output_image, iteration)
  local ext = paths.extname(output_image)
  local basename = paths.basename(output_image, ext)
  local directory = paths.dirname(output_image)
  return string.format('%s/%s_%d.%s',directory, basename, iteration, ext)
end

-- Preprocess an image before passing it to a Caffe model. 
-- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
-- and subtract the mean pixel.
function preprocess(img)
  local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
  local perm = torch.LongTensor{3, 2, 1}
  img = img:index(1, perm):mul(256.0)
  mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
  img:add(-1, mean_pixel)
  return img
end


-- Undo the above preprocessing.
function deprocess(img)
  local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68})
  mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img)
  img = img + mean_pixel
  local perm = torch.LongTensor{3, 2, 1}
  img = img:index(1, perm):div(256.0)
  return img
end

function deepcopy(orig)
    local orig_type = type(orig)
    local copy
    if orig_type == 'table' then
        copy = {}
        for orig_key, orig_value in next, orig, nil do
            copy[deepcopy(orig_key)] = deepcopy(orig_value)
        end
        setmetatable(copy, deepcopy(getmetatable(orig)))
    else -- number, string, boolean, etc
        copy = orig
    end
    return copy
end

-- Define an nn Module to compute content loss in-place
local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module')

function ContentLoss:__init(strength, target, normalize)
  parent.__init(self)
  self.strength = strength
  self.target = target
  self.normalize = normalize or false
  self.loss = 0
  self.crit = nn.MSECriterion()
end

function ContentLoss:updateOutput(input)
  if input:nElement() == self.target:nElement() then
    self.loss = self.crit:forward(input, self.target) * self.strength
  else
    print('WARNING: Skipping content loss')
  end
  self.output = input
  return self.output
end

function ContentLoss:updateGradInput(input, gradOutput)
  if input:nElement() == self.target:nElement() then
    self.gradInput = self.crit:backward(input, self.target)
  end
  if self.normalize then
    self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
  end
  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end

-- Returns a network that computes the CxC Gram matrix from inputs
-- of size C x H x W
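-- The View flattens the C x H x W input to a C x (H*W) matrix X; ConcatTable
-- duplicates X so that nn.MM(false, true) can compute X * X^T.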
function GramMatrix()
  local net = nn.Sequential()
  net:add(nn.View(-1):setNumInputDims(2))
  local concat = nn.ConcatTable()
  concat:add(nn.Identity())
  concat:add(nn.Identity())
  net:add(concat)
  net:add(nn.MM(false, true))
  return net
end


-- Define an nn Module to compute style loss in-place
local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module')

function StyleLoss:__init(strength, target, normalize)
  parent.__init(self)
  self.normalize = normalize or false
  self.strength = strength
  self.target = target
  self.loss = 0

  self.gram = GramMatrix()
  self.G = nil
  self.crit = nn.MSECriterion()
end

function StyleLoss:updateOutput(input)
  self.G = self.gram:forward(input)
  self.G:div(input:nElement())
  self.loss = self.crit:forward(self.G, self.target)
  self.loss = self.loss * self.strength
  self.output = input
  return self.output
end

function StyleLoss:updateGradInput(input, gradOutput)
  local dG = self.crit:backward(self.G, self.target)
  dG:div(input:nElement())
  self.gradInput = self.gram:backward(input, dG)
  if self.normalize then
    self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8)
  end
  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end


function ExtractMask(seg, color)
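  -- Build a binary mask selecting pixels of approximately the given color:
  -- each RGB channel is thresholded against 0 or 1 with a 0.1 tolerance
  -- (a band around 0.5 for grey), and the channel tests are ANDed via cmul.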
  local mask = nil
  if color == 'green' then 
    mask = torch.lt(seg[1], 0.1)
    mask:cmul(torch.gt(seg[2], 1-0.1))
    mask:cmul(torch.lt(seg[3], 0.1))
  elseif color == 'black' then 
    mask = torch.lt(seg[1], 0.1)
    mask:cmul(torch.lt(seg[2], 0.1))
    mask:cmul(torch.lt(seg[3], 0.1))
  elseif color == 'white' then
    mask = torch.gt(seg[1], 1-0.1)
    mask:cmul(torch.gt(seg[2], 1-0.1))
    mask:cmul(torch.gt(seg[3], 1-0.1))
  elseif color == 'red' then 
    mask = torch.gt(seg[1], 1-0.1)
    mask:cmul(torch.lt(seg[2], 0.1))
    mask:cmul(torch.lt(seg[3], 0.1))
  elseif color == 'blue' then
    mask = torch.lt(seg[1], 0.1)
    mask:cmul(torch.lt(seg[2], 0.1))
    mask:cmul(torch.gt(seg[3], 1-0.1))
  elseif color == 'yellow' then
    mask = torch.gt(seg[1], 1-0.1)
    mask:cmul(torch.gt(seg[2], 1-0.1))
    mask:cmul(torch.lt(seg[3], 0.1))
  elseif color == 'grey' then 
    mask = torch.cmul(torch.gt(seg[1], 0.5-0.1), torch.lt(seg[1], 0.5+0.1))
    mask:cmul(torch.cmul(torch.gt(seg[2], 0.5-0.1), torch.lt(seg[2], 0.5+0.1)))
    mask:cmul(torch.cmul(torch.gt(seg[3], 0.5-0.1), torch.lt(seg[3], 0.5+0.1)))
  elseif color == 'lightblue' then
    mask = torch.lt(seg[1], 0.1)
    mask:cmul(torch.gt(seg[2], 1-0.1))
    mask:cmul(torch.gt(seg[3], 1-0.1))
  elseif color == 'purple' then 
    mask = torch.gt(seg[1], 1-0.1)
    mask:cmul(torch.lt(seg[2], 0.1))
    mask:cmul(torch.gt(seg[3], 1-0.1))
  else 
    print('ExtractMask(): color not recognized, color = ', color)
  end 
  return mask:float()
end

-- Define style loss with segmentation 
local StyleLossWithSeg, parent = torch.class('nn.StyleLossWithSeg', 'nn.Module')

--function StyleLossWithSeg:__init(strength, target_grams, color_content_masks, content_seg_idxs, layer_id, normalize)
function StyleLossWithSeg:__init(strength, target_grams, color_content_masks, color_codes, layer_id, normalize)
  parent.__init(self)
  self.strength = strength
  self.target_grams = target_grams
  self.color_content_masks = deepcopy(color_content_masks)
  self.color_codes = color_codes
  --self.content_seg_idxs = content_seg_idxs
  self.normalize = normalize

  self.loss = 0
  self.gram = GramMatrix()
  self.crit = nn.MSECriterion()

  self.layer_id = layer_id
end 

function StyleLossWithSeg:updateOutput(input)
  self.output = input
  return self.output
end 

function StyleLossWithSeg:updateGradInput(input, gradOutput)
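  -- Unlike plain StyleLoss, the masked style loss is evaluated here in the
  -- backward pass: each color mask selects a region of the input features,
  -- its Gram matrix is compared with the matching target Gram, and the
  -- per-region gradients are accumulated into gradInput.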
  self.loss = 0
  self.gradInput = gradOutput:clone()
  self.gradInput:zero()
  for j = 1, #self.color_codes do 
    local l_content_mask_ori = self.color_content_masks[j]:clone():cuda()
    local l_content_mask = l_content_mask_ori:repeatTensor(1,1,1):expandAs(input) 
    local l_content_mean = l_content_mask_ori:mean()

    local masked_input_features = torch.cmul(l_content_mask, input)
    local masked_input_gram = self.gram:forward(masked_input_features):clone()
    if l_content_mean > 0 then 
      masked_input_gram:div(input:nElement() * l_content_mean)
    end

    local loss_j = self.crit:forward(masked_input_gram, self.target_grams[j])
    loss_j = loss_j * self.strength * l_content_mean
    self.loss = self.loss + loss_j

    local dG = self.crit:backward(masked_input_gram, self.target_grams[j])

    dG:div(input:nElement())

    local gradient = self.gram:backward(masked_input_features, dG) 

    if self.normalize then 
      gradient:div(torch.norm(gradient, 1) + 1e-8)
    end

    self.gradInput:add(gradient)
  end   

  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end 


local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module')

function TVLoss:__init(strength)
  parent.__init(self)
  self.strength = strength
  self.x_diff = torch.Tensor()
  self.y_diff = torch.Tensor()
end

function TVLoss:updateOutput(input)
  self.output = input
  return self.output
end

-- TV loss backward pass inspired by kaishengtai/neuralart
function TVLoss:updateGradInput(input, gradOutput)
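  -- Each pixel's gradient is the sum of its own forward differences (to the
  -- right and below) minus the contributions it makes to its neighbors'
  -- differences.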
  self.gradInput:resizeAs(input):zero()
  local C, H, W = input:size(1), input:size(2), input:size(3)
  self.x_diff:resize(3, H - 1, W - 1)
  self.y_diff:resize(3, H - 1, W - 1)
  self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}])
  self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}])
  self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff)
  self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff)
  self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff)
  self.gradInput:mul(self.strength)
  self.gradInput:add(gradOutput)
  return self.gradInput
end

function TVGradient(input, gradOutput, strength)
  local C, H, W = input:size(1), input:size(2), input:size(3)
  local gradInput = torch.CudaTensor(C, H, W):zero()
  local x_diff = torch.CudaTensor()
  local y_diff = torch.CudaTensor()
  x_diff:resize(3, H - 1, W - 1)
  y_diff:resize(3, H - 1, W - 1)
  x_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}])
  y_diff:copy(input[{{}, {1, -2}, {1, -2}}])
  y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}])
  gradInput[{{}, {1, -2}, {1, -2}}]:add(x_diff):add(y_diff)
  gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, x_diff)
  gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, y_diff)
  gradInput:mul(strength)
  gradInput:add(gradOutput)
  return gradInput
end 

--cmd:parse(arg) parses the input arguments.
-- The CmdLine method [table] parse(arg) parses a given table; arg is the default argument table created by Lua.
-- It returns a table of option values.
local params = cmd:parse(arg)
main(params)
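
With the options defined above, a typical invocation looks like the following (the image and segmentation paths are illustrative; the full two-stage pipeline is driven by the scripts in the deep-photo-styletransfer repository):

th neuralstyle_seg.lua -content_image examples/input/in1.png -style_image examples/style/tar1.png -content_seg examples/segmentation/in1.png -style_seg examples/segmentation/tar1.png -index 1 -serial examples/tmp_results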
