forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
THCUNN.lua
151 lines (128 loc) · 4.27 KB
/
THCUNN.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
local ffi = require 'ffi'
local THNN = require 'nn.THNN'
-- Module table returned at the end of this file.
local THCUNN = {}

-- load libTHCUNN
-- package.searchpath returns nil plus a message listing every path it tried
-- when the library cannot be found on package.cpath. Wrapping it in assert
-- raises that message instead of letting ffi.load fail on a nil argument.
THCUNN.C = ffi.load(assert(package.searchpath('libTHCUNN', package.cpath)))

-- load THC
-- On Windows the THC library is opened explicitly by name; on every other OS
-- the default C namespace (ffi.C) is used. Nothing in the visible part of this
-- file reads `THC` afterwards.
local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C

-- ctype for `THCState*`; THCUNN.getState uses it to wrap cutorch's state.
local THCState_ptr = ffi.typeof('THCState*')
--- Return the value of `cutorch.getState()` converted to a `THCState*` cdata.
-- Reads the global `cutorch`. This function is handed to `THNN.bind` further
-- down in this file as the state getter.
THCUNN.getState = function()
  return THCState_ptr(cutorch.getState())
end
-- Text of the generic declarations, as returned by the
-- 'cunn.THCUNN_generic_h' module.
local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'

-- Remove every line that begins with '#' so that no preprocessor directives
-- are left in the text: first the ones that follow a newline, then one that
-- sits at the very start of the text.
THCUNN_generic_h = string.gsub(THCUNN_generic_h, "\n#[^\n]*", "")
THCUNN_generic_h = string.gsub(THCUNN_generic_h, "^#[^\n]*\n", "")

-- Turn each 'TH_API void THNN_(name)' prefix into 'void THNN_TYPEname'.
-- The TYPE placeholder is replaced per tensor type before the text is passed
-- to ffi.cdef (see the replacements_generic loop).
local preprocessed_generic = THCUNN_generic_h:gsub('TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1')
-- List holding a single type-name substitution map.
-- NOTE(review): nothing in the visible part of this file reads this table.
local replacements = {
  {
    THTensor = 'THCudaTensor',
    THCIndexTensor = 'THCudaLongTensor',
    THIndex_t = 'long',
    THInteger_t = 'float',
  },
}
-- Maps C-side CUDA tensor type names to Torch tensor type names.
-- A 'THCudaHalfTensor' entry is added later when cutorch.hasHalf is set.
local cct2lt = {
  THCudaFloatTensor = 'torch.CudaTensor',
  THCudaDoubleTensor = 'torch.CudaDoubleTensor',
}
-- One substitution map per tensor type. Each map is applied textually to the
-- preprocessed generic declarations before they are passed to ffi.cdef:
-- keys are the placeholders found in the text, values are what replaces them.
local replacements_generic = {
  -- float tensors
  {
    THCTensor = 'THCudaTensor',
    THCIndexTensor = 'THCudaLongTensor',
    TYPE = 'Cuda',
    accreal = 'float',
  },
  -- double tensors
  {
    THCTensor = 'THCudaDoubleTensor',
    THCIndexTensor = 'THCudaLongTensor',
    TYPE = 'CudaDouble',
    accreal = 'double',
  },
}
-- When cutorch.hasHalf is set: declare the two float<->half conversion
-- functions, record the half tensor type name, and append a CudaHalf
-- substitution map so the declaration loop also emits the half variants.
if cutorch.hasHalf then
  ffi.cdef("half THC_float2half(float a);")
  ffi.cdef("float THC_half2float(half a);")

  cct2lt['THCudaHalfTensor'] = 'torch.CudaHalfTensor'

  replacements_generic[#replacements_generic + 1] = {
    THCTensor = 'THCudaHalfTensor',
    THCIndexTensor = 'THCudaLongTensor',
    TYPE = 'CudaHalf',
    accreal = 'float',
  }
end
-- Emit one set of C declarations per tensor type: start from the shared
-- template text, apply that type's substitutions, and pass the result to
-- ffi.cdef. Keys are used as Lua patterns; pairs() visits them in no
-- particular order.
for _, substitutions in ipairs(replacements_generic) do
  local declarations = preprocessed_generic
  for placeholder, replacement in pairs(substitutions) do
    declarations = declarations:gsub(placeholder, replacement)
  end
  ffi.cdef(declarations)
end
--- Collect the function names declared in generic header text.
-- Scans `s` for every 'TH_API void THNN_(name)' occurrence and returns the
-- captured names in order of appearance.
-- @param s header text to scan
-- @return array of function name strings (empty when nothing matches)
local function extract_function_names_generic(s)
  local declaration_pattern = 'TH_API void THNN_%(([%a%d_]+)%)'
  local names = {}
  for name in string.gmatch(s, declaration_pattern) do
    table.insert(names, name)
  end
  return names
end
--- Return the start index of every non-overlapping match of `p` in `s`.
-- `p` is interpreted as a Lua pattern (not a plain string), and matches are
-- searched left to right, each search resuming right after the previous
-- match.
-- A pattern that can match the empty string reports `stop == start - 1`;
-- resuming at `stop + 1` would then search the same position forever, so the
-- scan always advances by at least one character and stops at the end of `s`.
-- @param s string to search
-- @param p Lua pattern
-- @return array of 1-based start positions, in ascending order
local function find_positions(s, p)
  local positions = {}
  local init = 1
  local length = #s
  while init <= length do
    local start, stop = string.find(s, p, init)
    if start == nil then break end
    positions[#positions + 1] = start
    -- for a non-empty match this is stop + 1; for an empty one it is start + 1
    init = math.max(stop, start) + 1
  end
  return positions
end
--- Map each THNN function declared in `s` to the parameter slots mentioning 'real'.
-- Every 'TH_API ...;' declaration in `s` is examined. For each occurrence of
-- the substring 'real' inside a declaration, the 1-based index of the
-- comma-separated segment that contains it is recorded; an occurrence after
-- the last comma maps to (number of commas + 1).
-- The search is on the substring, so identifiers that merely contain 'real'
-- (for example 'accreal') are counted as well.
-- Declarations that do not contain 'void THNN_(name)' are skipped, in line
-- with extract_function_names_generic; indexing the result table with a nil
-- name would otherwise raise "table index is nil".
-- @param s header text to scan
-- @return table keyed by function name; each value is an array of indices
local function extract_function_names_and_real_args(s)
  local t = {}
  for declaration in string.gmatch(s, 'TH_API ([^;]+)') do
    local func_name = string.match(declaration, 'void THNN_%(([%a%d_]+)%)')
    if func_name then
      local comma_positions = find_positions(declaration, ',')
      local positions = {}
      for _, real_position in ipairs(find_positions(declaration, 'real')) do
        local found = false
        for index, comma_position in ipairs(comma_positions) do
          -- the first comma after the match closes the segment it belongs to
          if comma_position > real_position then
            positions[#positions + 1] = index
            found = true
            break
          end
        end
        -- no comma follows the match: it is in the last param
        if not found then
          positions[#positions + 1] = #comma_positions + 1
        end
      end
      t[func_name] = positions
    end
  end
  return t
end
-- Parameter slots mentioning 'real', per function name.
-- NOTE(review): not read again in the visible part of this file.
local real_args = extract_function_names_and_real_args(THCUNN_generic_h)

-- build function table
local function_names_generic = extract_function_names_generic(THCUNN_generic_h)

-- Bind the generic functions for one tensor type, store the result in
-- THNN.kernels under the Torch type name, and expose the same entry as the
-- `THNN` field of that tensor type's metatable.
local function register_kernels(tensor_type, type_name)
  THNN.kernels[tensor_type] = THNN.bind(THCUNN.C, function_names_generic, type_name, THCUNN.getState)
  torch.getmetatable(tensor_type).THNN = THNN.kernels[tensor_type]
end

register_kernels('torch.CudaTensor', 'Cuda')
register_kernels('torch.CudaDoubleTensor', 'CudaDouble')

-- half kernels are only registered when cutorch.hasHalf is set
if cutorch.hasHalf then
  register_kernels('torch.CudaHalfTensor', 'CudaHalf')
end
--- Build a method that converts its receiver to a fixed tensor type.
-- @param type value forwarded to `self:type(...)`; note that inside this
--   function the name shadows Lua's built-in `type`
-- @return function(self) that returns whatever `self:type(type)` returns
local function Module__converter(type)
  local function convert(self)
    return self:type(type)
  end
  return convert
end
-- Install type-conversion shortcuts on the nn.Module metatable: `cudaDouble`
-- always, `cudaHalf` only when cutorch.hasHalf is set. rawset writes the
-- fields directly, bypassing any __newindex metamethod.
do
  local module_metatable = torch.getmetatable('nn.Module')
  rawset(module_metatable, 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
  if cutorch.hasHalf then
    rawset(module_metatable, 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
  end
end

return THCUNN