conv fuse #288
base: master
@@ -0,0 +1,23 @@
import {glslRelu, glslSigmoid} from './unary-op';

export function getActicationSnippet(activation: string) {
  let activationFunction = '';
  let activationName = '';
  switch (activation) {
    case 'Relu':
      activationName = glslRelu().name;
      activationFunction = glslRelu().body;
      break;
    case 'Sigmoid':
      activationName = glslSigmoid().name;
      activationFunction = glslSigmoid().body;
      break;
    default:
      activationName = '';
      activationFunction = '';
  }
  const applyActivation = activation ? `
  value = ${activationName}(value);` : '';
  return {activationFunction, applyActivation};
}
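For context, a rough sketch of how a WebGL op's shader generator could consume this helper; the import path, the buildConvShaderSource function, and the surrounding GLSL are illustrative assumptions rather than code from this PR:

```ts
import {getActicationSnippet} from './fuse-utils';  // assumed path for the new helper

// Hypothetical Conv shader generation: inject the activation's GLSL body into
// the program source and apply it to the accumulated value before returning it.
function buildConvShaderSource(fusedActivation: string): string {
  const {activationFunction, applyActivation} = getActicationSnippet(fusedActivation);
  return `
    ${activationFunction}
    float process(int indices[4]) {
      float value = 0.0;
      // ... accumulate the convolution sum into 'value' ...
      ${applyActivation}
      return value;
    }`;
}

// e.g. for a Conv node whose '__internal_activation' attribute was set to 'Relu':
// buildConvShaderSource('Relu')
```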
@@ -53,6 +53,8 @@ export declare namespace Graph {
  export interface Transformer {
    removeAllIdentityNodes(): void;
    removeAllDropoutNodes(): void;

    fuseConvActivationNodes(): void;
    // TODO: add generic functions to manipulate the graph
  }
@@ -559,6 +561,7 @@ class GraphImpl implements Graph, Graph.Transformer {
    // apply common transform
    this.removeAllIdentityNodes();
    this.removeAllDropoutNodes();
    this.fuseConvActivationNodes();
Reviewer: If we apply fusion here, wouldn't it fail all backends except webgl?

Author: That's my guess too. So I submitted this PR and want to check how the CI goes; I'll look into the result tomorrow.
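One way to address this concern, sketched under the assumption of a per-backend capability flag that ONNX.js does not actually expose, would be to gate the fusion pass instead of running it unconditionally in the common transform:

```ts
// Minimal stand-ins for the Graph.Transformer methods shown in this diff.
interface Transformer {
  removeAllIdentityNodes(): void;
  removeAllDropoutNodes(): void;
  fuseConvActivationNodes(): void;
}

// Hypothetical capability flag; not a real ONNX.js field.
interface BackendCapabilities {
  supportsFusedConvActivation: boolean;
}

// Run the common transforms, but only fuse when the target backend can
// actually execute a Conv node that carries the fused-activation attribute.
function applyCommonTransforms(graph: Transformer, backend: BackendCapabilities): void {
  graph.removeAllIdentityNodes();
  graph.removeAllDropoutNodes();
  if (backend.supportsFusedConvActivation) {
    graph.fuseConvActivationNodes();
  }
}
```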
    // apply initializer specific transform
    if (graphInitializer) {
@@ -736,4 +739,27 @@ class GraphImpl implements Graph, Graph.Transformer {
      nodeIndex++;
    }
  }

  isActivation(n: Node): boolean {
    switch (n.opType) {
      // TODO: add other activation methods
      case 'Relu':
      case 'Sigmoid':
        return true;
      default:
        return false;
    }
  }
  fuseConvActivationNodes() {
    for (const node of this._nodes) {
      if (node.opType === 'Conv') {
        const next = this._allData[node.outputs[0]]._to;
        if (next.length === 1 && this.isActivation(this._nodes[next[0]])) {
          node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType));
          this.deleteNode(next[0]);
        }
      }
    }
  }
}

Reviewer: Technically, it's not enough to traverse the graph only once. Consider a graph with the pattern Conv+Relu+Sigmoid. After one iteration it's transformed to Conv+Sigmoid (as Relu is fused); ideally it should then further fuse Conv and Sigmoid. We may want to keep running the transformer until nothing can be transformed. It's fine to leave it for future work, as we don't have an immediate requirement for that.

Author: Theoretically you are right; we need to keep checking until no further fusing can take place. In practice, I am not sure a model would chain multiple activations in a row; it seems redundant to apply further activations once the signal is already 'activated' by the first one. Maybe there is a use case, but it shouldn't be very common. The implementation wouldn't be hard: we would insert possibly several internal attributes, one per activation function, and loop through them in Conv's shader generation. For the sake of simplicity, I'll keep it as it is for now; if we do see a need in the future, we can add this logic.

Reviewer: You're right that the example I used above was artificial. What I suggested is the standard approach for graph-level optimization, which is also used by ORT. In other common situations, like matmul+add+add (=> gemm+add => gemm), this approach is needed. But our current fusion transformer covers only two activations, so we don't have to do it now.
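The "keep running the transformer until nothing can be transformed" suggestion amounts to iterating the rewrite passes to a fixed point. A minimal sketch of that idea, assuming hypothetical passes that report whether they changed the graph (the PR's fuseConvActivationNodes returns void, so this is not its code):

```ts
// Hypothetical pass type: returns true if it rewrote the graph.
type TransformPass = () => boolean;

// Re-run every pass until a full sweep changes nothing (a fixed point), so
// chains like Conv+Relu+Sigmoid or matmul+add+add keep collapsing step by step.
function runToFixedPoint(passes: TransformPass[], maxIterations = 10): void {
  for (let i = 0; i < maxIterations; i++) {
    let changed = false;
    for (const pass of passes) {
      changed = pass() || changed;
    }
    if (!changed) {
      return;  // nothing left to transform
    }
  }
}
```

A wrapper around fuseConvActivationNodes that returns whether any '__internal_activation' attribute was set would slot directly into such a loop.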
Reviewer: Just for my education, why can't blend co-exist with fusion?

Reviewer: Any indication on perf from disabling blend while enabling activation?