From 83a5f17e9f4fa469a3e910d0ec75833239d7ffdf Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Tue, 22 Aug 2017 17:32:14 -0400 Subject: [PATCH] fix for latest torch --- PrintSize.lua | 36 ------------------------------------ SequencerCriterion.lua | 2 +- init.lua | 5 +++-- 3 files changed, 4 insertions(+), 39 deletions(-) delete mode 100644 PrintSize.lua diff --git a/PrintSize.lua b/PrintSize.lua deleted file mode 100644 index 1f1c64e..0000000 --- a/PrintSize.lua +++ /dev/null @@ -1,36 +0,0 @@ -local PrintSize, parent = torch.class('nn.PrintSize', 'nn.Module') - -function PrintSize:__init(prefix) - parent.__init(self) - self.prefix = prefix or "PrintSize" -end - -function PrintSize:updateOutput(input) - self.output = input - local size - if torch.type(input) == 'table' then - size = input - elseif torch.type(input) == 'nil' then - size = 'missing size' - else - size = input:size() - end - print(self.prefix..":input\n", size) - return self.output -end - - -function PrintSize:updateGradInput(input, gradOutput) - local size - if torch.type(gradOutput) == 'table' then - size = gradOutput - elseif torch.type(gradOutput) == 'nil' then - size = 'missing size' - else - size = gradOutput:size() - end - print(self.prefix..":gradOutput\n", size) - self.gradInput = gradOutput - return self.gradInput -end - diff --git a/SequencerCriterion.lua b/SequencerCriterion.lua index a1a6b99..adf777d 100644 --- a/SequencerCriterion.lua +++ b/SequencerCriterion.lua @@ -53,7 +53,7 @@ function SequencerCriterion:updateGradInput(input, target) end if self.sizeAverage then - nn.utils.recursiveDiv(tableGradInput[i], seqlen) + nn.utils.recursiveDiv(tableGradInput, seqlen) end if torch.isTensor(input) then diff --git a/init.lua b/init.lua index eb127e2..6ac3f4e 100644 --- a/init.lua +++ b/init.lua @@ -14,7 +14,7 @@ function nn.require(packagename) assert(torch.type(packagename) == 'string') local success, message = pcall(function() require(packagename) end) if not success then - print("missing package "..packagename..": run 'luarocks install nnx'") + print("missing package "..packagename..": run 'luarocks install '"..packagename.."'") error(message) end end @@ -22,7 +22,8 @@ end -- c lib: require "paths" -paths.require 'librnn' +pcall(function() paths.require 'librnn' end) -- Not sure why this works... +pcall(function() paths.require 'librnn' end) unpack = unpack or table.unpack