-
Notifications
You must be signed in to change notification settings - Fork 29
/
optimize-fconv.lua
38 lines (32 loc) · 1.03 KB
/
optimize-fconv.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the license found in the LICENSE file in
-- the root directory of this source tree. An additional grant of patent rights
-- can be found in the PATENTS file in the same directory.
--
--[[
--
-- Optimize a fconv model for fast generation.
--
--]]
require 'fairseq'
-- Command-line interface: the model file to read and the file that will
-- receive the optimized copy. Each entry is {flag, default, help text} and is
-- registered with torch.CmdLine in the order listed.
local cmd = torch.CmdLine()
do
   local option_specs = {
      {'-input_model', 'fconv_model.th7',
         'a th7 file that contains a fconv model'},
      {'-output_model', 'fconv_model_opt.th7',
         'an output file that will contain an optimized version'},
   }
   for _, spec in ipairs(option_specs) do
      cmd:option(spec[1], spec[2], spec[3])
   end
end
local config = cmd:parse(arg)
-- Load the serialized model and refuse to continue unless it is an
-- FConvModel (makeDecoderFast() and .module are used on it below).
local model = torch.load(config.input_model)
local model_type = torch.typename(model)
if model_type ~= 'FConvModel' then
   -- torch.typename() may return nil for non-Torch objects, hence tostring().
   error(string.format('"FConvModel" expected, got %s (loaded from %s)',
      tostring(model_type), tostring(config.input_model)))
end
-- Enable faster decoding
model:makeDecoderFast()

-- Clear output buffers and zero gradients for better compressibility: the
-- saved file keeps the gradient tensors, but zero-filled data compresses well.
model.module:clearState()
local _, gparams = model.module:parameters()
-- nn.Module:parameters() returns nothing for a module without parameters,
-- so the gradient list may be nil. pairs() (rather than 1..#gparams) also
-- reaches every tensor if the list contains nil holes; order is irrelevant.
if gparams then
   for _, grad in pairs(gparams) do
      grad:zero()
   end
end

torch.save(config.output_model, model)