conv fuse #288

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions lib/backends/webgl/ops/conv-pack.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../../../attribute';
import {Logger} from '../../../instrument';
import {Conv} from '../../../ops/conv';
import {Tensor} from '../../../tensor';
@@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv {
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
const matmul = new WebGLMatMulPacked();
if (!!this.activation) {
const attributes = new Attribute(undefined);
attributes.set('__internal_activation', 'string', (this.activation));
matmul.initialize(attributes);
}
const reshape = new WebGLReshapePacked();
// shape for kernel reshape
const shape =
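Note that WebGLConvPacked emits no activation code itself: the packed path decomposes Conv into im2col → packed matmul → reshape, and the snippet above simply forwards the fused activation to the matmul through a fresh Attribute carrying '__internal_activation', so the activation is applied exactly once inside the matmul shader. A rough sketch of the decomposition (operand order and inference-handler plumbing simplified; names are placeholders, not the literal calls in conv-pack.ts):

    // sketch only
    const col     = im2col.run(handler, [x, w]);            // unfold input patches
    const product = matmul.run(handler, [wReshaped, col]);  // fused activation applied here
    const output  = reshape.run(handler, [product, shape]); // back to the NCHW conv output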
37 changes: 23 additions & 14 deletions lib/backends/webgl/ops/conv.ts
@@ -10,7 +10,9 @@ import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types';
import {WebGLContext} from '../webgl-context';

import {WebGLConvPacked} from './conv-pack';
import {getActicationSnippet} from './fuse_utils';

export class WebGLConv extends Conv {
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
@@ -66,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const hasBias = inputs.length > 2;
const processBias = hasBias ? `dotProd += getBias(output_channel);` : ``;
const processBias = hasBias ? `value += getBias(output_channel);` : ``;
const xShape = inputs[0].dims.slice();
const wShape = inputs[1].dims.slice();
const outputChannelsPerGroup = wShape[0] / this.group;
@@ -85,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides);
const glsl = getGlsl(handler.session.backend.glContext.version);

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const shaderSource = `
const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]});
const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]});

${activationFunction}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int output_channel = coords.y;
ivec2 xRCCorner = coords.zw * strides - pads;
int group_id = output_channel / ${outputChannelsPerGroup};

float dotProd = 0.0;
float value = 0.0;
for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
int input_channel = group_id * ${wShape[1]} + wInChannel;
for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
@@ -114,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

float xVal = getX(batch, input_channel, xWidth, xHeight);
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
dotProd += xVal*wVal;
value += xVal*wVal;
}
}
}
${processBias}
${glsl.output} = vec4(dotProd, .0, .0, .0);
${applyActivation}
${glsl.output} = vec4(value, .0, .0, .0);
}
`;
return {
@@ -215,7 +220,6 @@ export class WebGLUnpackedConv extends Conv {
let blend = false;
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);

if (k === sharedDimReadSize) {
blend = true;
gl.enable(gl.BLEND);
@@ -248,6 +252,7 @@ export class WebGLUnpackedConv extends Conv {
const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});

const shaderSource = `
const int XC = ${xshape[1]};
const int XH = ${xshape[2]};
@@ -263,13 +268,12 @@ export class WebGLUnpackedConv extends Conv {
const int KHKW = KH*KW;
const int XCKHKW = XC * KHKW;
const int outputChannels = 4;

vec4 process(int indices[${rank}]) {
int b = indices[0]; // batch size
int oh = indices[1] * strideH - padH; //output height
int ow = indices[2] * strideW - padW; //output width
int p = indices[3] * outputChannels; //patch
vec4 v = vec4(0.0);
vec4 value = vec4(0.0);
for(int i=0; i < outputChannels; ++i) {
if(p < XCKHKW) {
int patchC = p / KHKW;
@@ -286,12 +290,12 @@ export class WebGLUnpackedConv extends Conv {
xh2 < XH &&
xw2 >= 0 &&
xw2 < XW) {
v[i] = _X(x);
value[i] = _X(x);
}
}
++p;
}
return v;
return value;
}
`;
return {
@@ -321,16 +325,20 @@ export class WebGLUnpackedConv extends Conv {
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
const sharedDim = im2colLayout.shape[3];
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
Member commented:
Just for my education, why can't blend co-exist with fusion?

Member commented:
Any indication on perf by disabling blend while enabling activation?
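(A plausible reading of the `&& !this.activation` guard above, not confirmed in this thread: when blending is enabled, the shared dimension is split across several draw passes and the partial sums are accumulated by additive GL blending, but a non-linear activation does not distribute over that sum — relu(-3 + 5) = 2 while relu(-3) + relu(5) = 5 — so the activation can only be applied once the full sum is available in a single pass.)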

const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
sharedDim;
const samplers = ['Im2Col', 'K'];
if (inputs.length === 3) {
samplers.push('B');
}

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
${activationFunction}
float process(int indices[${rank}]) {
int b[1];
b[0] = indices[1];
Expand All @@ -341,15 +349,16 @@ export class WebGLUnpackedConv extends Conv {
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
float value = sharedDimOffset == 0 ? ${initValue} : 0.0;
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
++im2colOffset;
++kernelOffset;
}
return sum;
${applyActivation}
return value;
}`;
return {
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],
23 changes: 23 additions & 0 deletions lib/backends/webgl/ops/fuse_utils.ts
@@ -0,0 +1,23 @@
import {glslRelu, glslSigmoid} from './unary-op';

export function getActicationSnippet(activation: string) {
let activationFunction = '';
let activationName = '';
switch (activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = activation ? `
value = ${activationName}(value);` :
'';
return {activationFunction, applyActivation};
}
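For orientation, a sketch of how the two returned strings are consumed by the shader generators above. The actual GLSL emitted by glslRelu lives in unary-op.ts, so the helper body shown here is only an assumed illustration. Note that applyActivation hard-codes the accumulator name `value`, which is why this PR renames dotProd/sum/v to value in the conv shaders.

    // sketch only — the real glslRelu() output may differ
    const {activationFunction, applyActivation} = getActicationSnippet('Relu');
    // activationFunction ≈ 'float relu(float a) { return max(a, 0.0); }'
    // applyActivation    ≈ 'value = relu(value);'
    const shaderSource = `
      ${activationFunction}
      float process(int indices[2]) {
        float value = 0.0;
        // ... accumulate dot products into value ...
        ${applyActivation}
        return value;
      }`;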
5 changes: 5 additions & 0 deletions lib/backends/webgl/ops/matmul-pack.ts
@@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor';
import {BroadcastUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';
import {getActicationSnippet} from './fuse_utils';

export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
@@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
const aRank = aShape.length;
const bRank = bShape.length;
const sharedDim = aShape[aShape.length - 1];

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
// TODO:fix broadcasting
const shaderSource = `
${activationFunction}
vec4 process(int indices[${rank}]) {
int a[${aRank}];
int b[${bRank}];
Expand All @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba;
}
${processBias}
${applyActivation}
return value;
}`;
return {
26 changes: 26 additions & 0 deletions lib/graph.ts
@@ -53,6 +53,8 @@ export declare namespace Graph {
export interface Transformer {
removeAllIdentityNodes(): void;
removeAllDropoutNodes(): void;

fuseConvActivationNodes(): void;
// TODO: add generic functions to manipulate the graph
}

@@ -559,6 +561,7 @@ class GraphImpl implements Graph, Graph.Transformer {
// apply common transform
this.removeAllIdentityNodes();
this.removeAllDropoutNodes();
this.fuseConvActivationNodes();
@duli2012 (Member) commented on Apr 27, 2021:
If we apply fusion here, wouldn't it fail all backends except webgl?
I'm a little surprised that the BrowserStack CI passed. Did I miss anything?

Contributor (Author) replied:
That's my guess too, so I submitted this PR to check how the CI goes. I'll look into the results tomorrow.


// apply initializer specific transform
if (graphInitializer) {
@@ -736,4 +739,27 @@ class GraphImpl implements Graph, Graph.Transformer {
nodeIndex++;
}
}

isActivation(n: Node): boolean {
switch (n.opType) {
// TODO: add other activation methods
case 'Relu':
case 'Sigmoid':
return true;
default:
return false;
}
}

fuseConvActivationNodes() {
@duli2012 (Member) commented on Apr 27, 2021:
Technically, it's not enough to traverse the graph only once. Consider a graph with pattern Conv+Relu+Sigmoid. After one iteration, it's transformed to Conv+Sigmoid (as Relu is fused). Ideally, it should further fuse Conv and Sigmoid. We may want to keep running the transformer until nothing can be transformed. It's fine to leave it for future work as we don't have immediate requirements for that.

Contributor (Author) replied:
Theoretically you are right; we need to keep checking until no further fusing can take place. In practice, I am not sure a model would chain multiple activations in a row: it seems redundant to apply further activations once the signal is already 'activated' by the first one. Maybe there is a use case, but it shouldn't be very common, I guess.

The implementation won't be hard: we would just need to insert possibly several internal attributes, one per activation function, and loop through them in Conv's shader generation. For the sake of simplicity, I'll keep it as it is for now; if we do see a need in the future, we can add this logic in.

Member replied:
You're right that the example I used above was artificial. What I suggested is kind of the standard approach for graph level optimization, which is also used by ORT. In other common situations like matmul+add+add (=> gemm+add => gemm), this approach is needed. But our current fusion transformer covers only two activations, so we don't have to do it now.

for (const node of this._nodes) {
if (node.opType === 'Conv') {
const next = this._allData[node.outputs[0]]._to;
if (next.length === 1 && this.isActivation(this._nodes[next[0]])) {
node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType));
this.deleteNode(next[0]);
}
}
}
}
}
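As a follow-up to the review discussion above, a fixpoint-style driver could look roughly like the sketch below. It assumes each transformer is changed to return true when it modified the graph, which is not how the void-returning methods in this PR behave; it is only an illustration of the reviewer's suggestion.

    // hypothetical sketch, not part of this PR
    private runTransformsToFixpoint() {
      let changed = true;
      while (changed) {
        changed = false;
        changed = this.removeAllIdentityNodes() || changed;
        changed = this.removeAllDropoutNodes() || changed;
        changed = this.fuseConvActivationNodes() || changed;
      }
    }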
2 changes: 2 additions & 0 deletions lib/ops/conv.ts
@@ -17,6 +17,7 @@ export abstract class Conv implements Operator {
this.kernelShape = attributes.getInts('kernel_shape', []);
this.pads = attributes.getInts('pads', [0, 0, 0, 0]);
this.strides = attributes.getInts('strides', [1, 1]);
this.activation = attributes.getString('__internal_activation', '');
}

checkInputs(inputs: Tensor[]): boolean {
@@ -88,4 +89,5 @@ export abstract class Conv implements Operator {
protected kernelShape: number[];
protected pads: number[];
protected strides: number[];
protected activation: string;
}
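Putting the pieces together: the graph transformer tags the fused Conv node with '__internal_activation', and Conv.initialize reads the tag back with an empty-string default, so graphs where no fusion happened keep activation === '' and generate the same shaders as before. A minimal sketch of the round trip, with the node lookup simplified:

    // during graph transformation (lib/graph.ts)
    node.attributes.set('__internal_activation', 'string', 'Relu');
    // during operator initialization (lib/ops/conv.ts)
    this.activation = attributes.getString('__internal_activation', ''); // '' when nothing was fused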
5 changes: 4 additions & 1 deletion lib/ops/matmul.ts
@@ -9,7 +9,9 @@ import {Tensor} from '../tensor';
export abstract class MatMul implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(attributes: Attribute): void {}
initialize(attributes: Attribute): void {
this.activation = attributes.getString('__internal_activation', '');
}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length !== 2) {
@@ -38,4 +40,5 @@ export abstract class MatMul implements Operator {

return true;
}
protected activation: string;
}
Binary file added test/data/teams_model/msra_190729.onnx (contents not shown)
Binary file added test/data/teams_model/test_data_set_0/input_0.pb (contents not shown)
Other binary test-data files were also added or modified (contents not shown)