Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to apply softmax to a batch of data with variable length #1297

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions LenSoftMax.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
local LenSoftMax, parent = torch.class('nn.LenSoftMax', 'nn.Module')

function LenSoftMax:__init()
parent.__init(self)
self.gradInput = {torch.Tensor()}
end

function LenSoftMax:updateOutput(input)
local _input, _len = unpack(input)
_input.THNN.LenSoftMax_updateOutput(
_input:cdata(),
self.output:cdata(),
_len:cdata()
)
return self.output
end

function LenSoftMax:updateGradInput(input, gradOutput)
local _input, _len = unpack(input)
_input.THNN.LenSoftMax_updateGradInput(
_input:cdata(),
gradOutput:cdata(),
self.gradInput[1]:cdata(),
self.output:cdata(),
_len:cdata()
)
if not self.gradInput[2] then
self.gradInput[2] = _len.new()
end
self.gradInput[2]:resizeAs(_len):zero()
return self.gradInput
end
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ require('nn.LogSigmoid')
require('nn.LogSoftMax')
require('nn.Sigmoid')
require('nn.SoftMax')
require('nn.LenSoftMax')
require('nn.SoftMin')
require('nn.SoftPlus')
require('nn.SoftSign')
Expand Down
116 changes: 116 additions & 0 deletions lib/THNN/generic/LenSoftMax.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/LenSoftMax.c"
#else

void THNN_(LenSoftMax_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
THIndexTensor *len)
{
if ((input->nDimension != 2) && (len->nDimension != 1))
{
THArgCheck(0, 2, "2D tensor expected for input, 1D tensor expected for len");
}

real *input_data, *output_data;
THIndex_t *len_data;
ptrdiff_t nframe = input->size[0], dim = input->size[1];
ptrdiff_t t;

input = THTensor_(newContiguous)(input);
THTensor_(resizeAs)(output, input);

input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
len_data = THIndexTensor_(data)(len);

#pragma omp parallel for private(t)
for (t = 0; t < nframe; t++)
{
real *input_ptr = input_data + t*dim;
real *output_ptr = output_data + t*dim;

real inputMax = -THInf;
accreal sum;

ptrdiff_t d, ld = (ptrdiff_t)len_data[t];
for (d = 0; d < ld; d++)
{
if (input_ptr[d] >= inputMax) inputMax = input_ptr[d];
}

sum = 0;
for (d = 0; d < ld; d++)
{
real z = exp(input_ptr[d] - inputMax);
output_ptr[d] = z;
sum += z;
}
for (d = ld; d < dim; d++)
{
output_ptr[d] = 0;
}

for (d = 0; d < ld; d++)
{
output_ptr[d] *= 1/sum;
}
}

THTensor_(free)(input);
}

void THNN_(LenSoftMax_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *output,
THIndexTensor *len)
{
THNN_CHECK_SHAPE(input, gradOutput);

if ((output->nDimension != 2) && (len->nDimension != 1))
{
THError("2D tensor expected for input, 1D tensor expected for len");
}

real *gradInput_data, *gradOutput_data, *output_data;
THIndex_t *len_data;
ptrdiff_t nframe = output->size[0], dim = output->size[1];
ptrdiff_t t;

gradOutput = THTensor_(newContiguous)(gradOutput);
output = THTensor_(newContiguous)(output);

THTensor_(resizeAs)(gradInput, output);
gradInput_data = THTensor_(data)(gradInput);
output_data = THTensor_(data)(output);
gradOutput_data = THTensor_(data)(gradOutput);
len_data = THIndexTensor_(data)(len);

#pragma omp parallel for private(t)
for (t = 0; t < nframe; t++)
{
real *gradInput_ptr = gradInput_data + t*dim;
real *output_ptr = output_data + t*dim;
real *gradOutput_ptr = gradOutput_data + t*dim;

ptrdiff_t d, ld = (ptrdiff_t)len_data[t];
accreal sum = 0;
for (d = 0; d < ld; d++)
sum += (accreal)gradOutput_ptr[d] * output_ptr[d];

for (d = 0; d < ld; d++)
gradInput_ptr[d] = output_ptr[d] * (gradOutput_ptr[d] - sum);

for (d = ld; d < dim; d++)
gradInput_ptr[d] = 0;
}

THTensor_(free)(gradOutput);
THTensor_(free)(output);
}

#endif
13 changes: 13 additions & 0 deletions lib/THNN/generic/THNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,19 @@ TH_API void THNN_(SoftMax_updateGradInput)(
THTensor *gradInput,
THTensor *output);

TH_API void THNN_(LenSoftMax_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
THIndexTensor *len);
TH_API void THNN_(LenSoftMax_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *output,
THIndexTensor *len);

TH_API void THNN_(SoftPlus_updateOutput)(
THNNState *state,
THTensor *input,
Expand Down
3 changes: 3 additions & 0 deletions lib/THNN/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
#include "generic/SoftMax.c"
#include "THGenerateFloatTypes.h"

#include "generic/LenSoftMax.c"
#include "THGenerateFloatTypes.h"

#include "generic/SoftPlus.c"
#include "THGenerateFloatTypes.h"

Expand Down