Skip to content

Commit

Permalink
Don't output derivative lines for constant nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
davepagurek committed Jun 9, 2023
1 parent a284c15 commit a6037ca
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 13 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@davepagurek/glsl-autodiff",
"version": "0.0.19",
"version": "0.0.20",
"main": "build/autodiff.js",
"author": "Dave Pagurek <[email protected]>",
"license": "MIT",
Expand Down
10 changes: 9 additions & 1 deletion src/arithmetic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ export class Mult extends Op {
}
derivative(param: Param) {
const [f, g] = this.dependsOn
return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}`
const fIsConst = f.isConst(param)
const gIsConst = g.isConst(param)
if (fIsConst && !gIsConst) {
return `${f.ref()}*${g.derivRef(param)}`
} else if (!fIsConst && gIsConst) {
return `${g.ref()}*${f.derivRef(param)}`
} else {
return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}`
}
}
}

Expand Down
29 changes: 21 additions & 8 deletions src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,14 @@ export abstract class Op {
}
}

public zeroDerivative() {
return '0.0'
}

public derivRef(param: Param): string {
if (this.useTempVar()) {
if (this.isConst(param)) {
return this.zeroDerivative()
} else if (this.useTempVar()) {
return `_glslad_dv${this.id}_d${param.safeName()}`
} else {
return `(${this.derivative(param)})`
Expand All @@ -127,15 +133,15 @@ export abstract class Op {
}

public derivInitializer(param: Param): string {
if (this.useTempVar()) {
return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n`
} else {
if (this.isConst(param) || !this.useTempVar()) {
return ''
} else {
return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n`
}
}

public isConst(): boolean {
return this.dependsOn.every((op) => op.isConst())
public isConst(param?: Param): boolean {
return this.dependsOn.every((op) => op.isConst(param))
}

public outputDependencies({ deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
Expand All @@ -162,6 +168,7 @@ export abstract class Op {
}

public outputDerivDependencies(param: Param, { deps, derivDeps }: { deps: Set<Op>; derivDeps: Map<Param, Set<Op>> }): string {
if (this.isConst()) return ''
let code = ''
for (const op of this.dependsOn) {
if (!deps.has(op)) {
Expand All @@ -170,7 +177,7 @@ export abstract class Op {
code += op.initializer()
}

if (!derivDeps.get(param)?.has(op)) {
if (!derivDeps.get(param)?.has(op) && !op.isConst(param)) {
const paramDerivDeps = derivDeps.get(param) ?? new Set<Op>()
paramDerivDeps.add(op)
derivDeps.set(param, paramDerivDeps)
Expand Down Expand Up @@ -318,7 +325,13 @@ export class Param extends OpLiteral {
}).join('') + this.id // Add id to ensure uniqueness
}

isConst() { return false }
isConst(param?: Param) {
if (param) {
return param !== this
} else {
return false
}
}
definition() { return this.name }
derivative(param: Param) {
if (param === this) {
Expand Down
30 changes: 27 additions & 3 deletions src/vecBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ export class VecParamElementRef extends Param {
this.ad.registerParam(this, this.name)
}

public isConst(param?: Param) {
if (param) {
return param !== this
} else {
return false
}
}

definition() { return `${this.dependsOn[0].ref()}.${this.prop}` }
derivative(param: Param) {
if (param === this) {
Expand Down Expand Up @@ -91,6 +99,14 @@ export abstract class VectorOp extends Op {
}
}

public glslType() {
return `vec${this.size()}`
}

zeroDerivative() {
return `${this.glslType()}(0.0)`
}

public u() { return this.x() }
public v() { return this.y() }
public r() { return this.x() }
Expand Down Expand Up @@ -363,13 +379,21 @@ export class VecParam extends VectorOp {
return `vec${this.size()}(${this.getElems().map((el) => el.derivRef(param)).join(',')})`
}

public isConst(param?: Param) {
if (param) {
return param !== this.x() && param !== this.y() && param !== this.z()
} else {
return false
}
}

public override initializer() { return '' }
public override ref() { return this.definition() }
public override derivInitializer(param: Param) {
if (this.useTempVar()) {
return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n`
} else {
if (this.isConst(param) || !this.useTempVar()) {
return ''
} else {
return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n`
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions test/simple-wiggle/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<!DOCTYPE html>
<html>
<head>
<title>glsl-autodiff test</title>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.3.1/p5.min.js"></script>
<script type="text/javascript" src="../../build/autodiff.js"></script>
<script type="text/javascript" src="test.js"></script>
</head>
<body>
</body>
</html>
137 changes: 137 additions & 0 deletions test/simple-wiggle/test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
const vert = `
attribute vec3 aPosition;
attribute vec3 aNormal;
attribute vec2 aTexCoord;
uniform mat4 uModelViewMatrix;
uniform mat4 uProjectionMatrix;
uniform mat3 uNormalMatrix;
uniform float time;
varying vec2 vTexCoord;
varying vec3 vNormal;
varying vec3 vPosition;
void main(void) {
vec4 objSpacePosition = vec4(aPosition, 1.0);
float origZ = objSpacePosition.z;
${AutoDiff.gen((ad) => {
const pos = ad.vec3Param('objSpacePosition')
const y = pos.y()
const time = ad.param('time')
offset = ad.vec3(time.mult(0.005).add(y.mult(2)).sin().mult(0.5), 0, 0)
offset.output('offset')
offset.adjustNormal(ad.vec3Param('aNormal'), pos).output('normal')
//offset.output('z')
//offset.outputDeriv('dzdx', x)
//offset.outputDeriv('dzdy', y)
}, { debug: true, maxDepthPerVariable: 8 })}
objSpacePosition.xyz += offset;
//vec3 slopeX = vec3(1.0, 0.0, dzdx);
//vec3 slopeY = vec3(0.0, 1.0, dzdy);
vec4 worldSpacePosition = uModelViewMatrix * objSpacePosition;
gl_Position = uProjectionMatrix * worldSpacePosition;
vTexCoord = aTexCoord;
vPosition = worldSpacePosition.xyz;
//vNormal = uNormalMatrix * aNormal;
//normal=cross(_glslad_v66,_glslad_v65);
//normal=_glslad_v66;
vNormal = uNormalMatrix * normal;
}
`
console.log(vert)

const frag = `
precision mediump float;
const int MAX_LIGHTS = 3;
varying vec2 vTexCoord;
varying vec3 vNormal;
varying vec3 vPosition;
uniform sampler2D img;
uniform int numLights;
uniform vec3 lightPositions[MAX_LIGHTS];
uniform vec3 lightColors[MAX_LIGHTS];
uniform float lightStrengths[MAX_LIGHTS];
uniform vec3 ambientLight;
uniform float materialShininess;
void main(void) {
vec3 materialColor = texture2D(img, vTexCoord).rgb;
vec3 normal = normalize(vNormal);
gl_FragColor = vec4(abs(normal), 1.); return;
//gl_FragColor = length(vNormal) * vec4(1.); return;
vec3 color = vec3(0.0, 0.0, 0.0);
for (int i = 0; i < MAX_LIGHTS; i++) {
if (i >= numLights) break;
vec3 lightPosition = lightPositions[i];
float distanceSquared = 0.0; /*0.00015*dot(
lightPosition - vPosition,
lightPosition - vPosition);*/
vec3 lightDir = normalize(lightPosition - vPosition);
float lambertian = max(dot(lightDir, normal), 0.0);
color += lambertian * materialColor * lightColors[i] *
lightStrengths[i] / (1.0 + distanceSquared);
vec3 viewDir = normalize(-vPosition);
float spec = pow(
max(dot(viewDir, reflect(-lightDir, normal)), 0.0),
materialShininess);
color += spec * lightStrengths[i] * lightColors[i] /
(1.0 + distanceSquared);
}
color += ambientLight * materialColor;
gl_FragColor = vec4(color, 1.0);
}
`

let distortShader
let texture
function setup() {
createCanvas(800, 600, WEBGL)
distortShader = createShader(vert, frag)
texture = createGraphics(500, 500)
}

const lights = [{
position: [200, 50, -100],
color: [1, 1, 1],
strength: 0.5,
},
{
position: [-200, -50, -100],
color: [1, 1, 1],
strength: 0.5,
},
];

function draw() {
texture.background(255, 0, 0)
texture.fill(255)
texture.noStroke()
texture.textSize(70)
texture.textAlign(CENTER, CENTER)
texture.text('hello, world', texture.width / 2, texture.height / 2)

background(0)

const shininess = 1000
const ambient = [0.2, 0.2, 0.2]

orbitControl()
noStroke()
shader(distortShader)
distortShader.setUniform('img', texture)
distortShader.setUniform('lightPositions', lights.map(l => l.position).flat())
distortShader.setUniform('lightColors', lights.map(l => l.color).flat())
distortShader.setUniform('lightStrengths', lights.map(l => l.strength).flat())
distortShader.setUniform('numLights', lights.length)
distortShader.setUniform('ambientLight', ambient)
distortShader.setUniform('materialShininess', shininess)
distortShader.setUniform('time', millis())
push()
sphere(200, 60, 30)
pop()
}

0 comments on commit a6037ca

Please sign in to comment.