diff --git a/package.json b/package.json index 9ce2b86..6394c02 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@davepagurek/glsl-autodiff", - "version": "0.0.19", + "version": "0.0.20", "main": "build/autodiff.js", "author": "Dave Pagurek ", "license": "MIT", diff --git a/src/arithmetic.ts b/src/arithmetic.ts index 4333cc6..2d80afa 100644 --- a/src/arithmetic.ts +++ b/src/arithmetic.ts @@ -24,7 +24,15 @@ export class Mult extends Op { } derivative(param: Param) { const [f, g] = this.dependsOn - return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}` + const fIsConst = f.isConst(param) + const gIsConst = g.isConst(param) + if (fIsConst && !gIsConst) { + return `${f.ref()}*${g.derivRef(param)}` + } else if (!fIsConst && gIsConst) { + return `${g.ref()}*${f.derivRef(param)}` + } else { + return `${f.ref()}*${g.derivRef(param)}+${g.ref()}*${f.derivRef(param)}` + } } } diff --git a/src/base.ts b/src/base.ts index 62badc2..370de73 100644 --- a/src/base.ts +++ b/src/base.ts @@ -105,8 +105,14 @@ export abstract class Op { } } + public zeroDerivative() { + return '0.0' + } + public derivRef(param: Param): string { - if (this.useTempVar()) { + if (this.isConst(param)) { + return this.zeroDerivative() + } else if (this.useTempVar()) { return `_glslad_dv${this.id}_d${param.safeName()}` } else { return `(${this.derivative(param)})` @@ -127,15 +133,15 @@ export abstract class Op { } public derivInitializer(param: Param): string { - if (this.useTempVar()) { - return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n` - } else { + if (this.isConst(param) || !this.useTempVar()) { return '' + } else { + return `${this.glslType()} ${this.derivRef(param)}=${this.derivative(param)};\n` } } - public isConst(): boolean { - return this.dependsOn.every((op) => op.isConst()) + public isConst(param?: Param): boolean { + return this.dependsOn.every((op) => op.isConst(param)) } public outputDependencies({ deps, derivDeps }: { deps: Set; derivDeps: Map> }): string { @@ -162,6 +168,7 @@ export abstract class Op { } public outputDerivDependencies(param: Param, { deps, derivDeps }: { deps: Set; derivDeps: Map> }): string { + if (this.isConst()) return '' let code = '' for (const op of this.dependsOn) { if (!deps.has(op)) { @@ -170,7 +177,7 @@ export abstract class Op { code += op.initializer() } - if (!derivDeps.get(param)?.has(op)) { + if (!derivDeps.get(param)?.has(op) && !op.isConst(param)) { const paramDerivDeps = derivDeps.get(param) ?? new Set() paramDerivDeps.add(op) derivDeps.set(param, paramDerivDeps) @@ -318,7 +325,13 @@ export class Param extends OpLiteral { }).join('') + this.id // Add id to ensure uniqueness } - isConst() { return false } + isConst(param?: Param) { + if (param) { + return param !== this + } else { + return false + } + } definition() { return this.name } derivative(param: Param) { if (param === this) { diff --git a/src/vecBase.ts b/src/vecBase.ts index ac22691..ef6c29d 100644 --- a/src/vecBase.ts +++ b/src/vecBase.ts @@ -33,6 +33,14 @@ export class VecParamElementRef extends Param { this.ad.registerParam(this, this.name) } + public isConst(param?: Param) { + if (param) { + return param !== this + } else { + return false + } + } + definition() { return `${this.dependsOn[0].ref()}.${this.prop}` } derivative(param: Param) { if (param === this) { @@ -91,6 +99,14 @@ export abstract class VectorOp extends Op { } } + public glslType() { + return `vec${this.size()}` + } + + zeroDerivative() { + return `${this.glslType()}(0.0)` + } + public u() { return this.x() } public v() { return this.y() } public r() { return this.x() } @@ -363,13 +379,21 @@ export class VecParam extends VectorOp { return `vec${this.size()}(${this.getElems().map((el) => el.derivRef(param)).join(',')})` } + public isConst(param?: Param) { + if (param) { + return param !== this.x() && param !== this.y() && param !== this.z() + } else { + return false + } + } + public override initializer() { return '' } public override ref() { return this.definition() } public override derivInitializer(param: Param) { - if (this.useTempVar()) { - return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n` - } else { + if (this.isConst(param) || !this.useTempVar()) { return '' + } else { + return `vec${this.size()} ${this.derivRef(param)}=${this.derivative(param)};\n` } } } diff --git a/test/simple-wiggle/index.html b/test/simple-wiggle/index.html new file mode 100644 index 0000000..d572cda --- /dev/null +++ b/test/simple-wiggle/index.html @@ -0,0 +1,11 @@ + + + + glsl-autodiff test + + + + + + + diff --git a/test/simple-wiggle/test.js b/test/simple-wiggle/test.js new file mode 100644 index 0000000..04fd056 --- /dev/null +++ b/test/simple-wiggle/test.js @@ -0,0 +1,137 @@ +const vert = ` +attribute vec3 aPosition; +attribute vec3 aNormal; +attribute vec2 aTexCoord; + +uniform mat4 uModelViewMatrix; +uniform mat4 uProjectionMatrix; +uniform mat3 uNormalMatrix; + +uniform float time; + +varying vec2 vTexCoord; +varying vec3 vNormal; +varying vec3 vPosition; + +void main(void) { + vec4 objSpacePosition = vec4(aPosition, 1.0); + float origZ = objSpacePosition.z; + ${AutoDiff.gen((ad) => { + const pos = ad.vec3Param('objSpacePosition') + const y = pos.y() + const time = ad.param('time') + + offset = ad.vec3(time.mult(0.005).add(y.mult(2)).sin().mult(0.5), 0, 0) + offset.output('offset') + offset.adjustNormal(ad.vec3Param('aNormal'), pos).output('normal') + //offset.output('z') + //offset.outputDeriv('dzdx', x) + //offset.outputDeriv('dzdy', y) + }, { debug: true, maxDepthPerVariable: 8 })} + objSpacePosition.xyz += offset; + //vec3 slopeX = vec3(1.0, 0.0, dzdx); + //vec3 slopeY = vec3(0.0, 1.0, dzdy); + vec4 worldSpacePosition = uModelViewMatrix * objSpacePosition; + gl_Position = uProjectionMatrix * worldSpacePosition; + vTexCoord = aTexCoord; + vPosition = worldSpacePosition.xyz; + //vNormal = uNormalMatrix * aNormal; + //normal=cross(_glslad_v66,_glslad_v65); + //normal=_glslad_v66; + vNormal = uNormalMatrix * normal; +} +` +console.log(vert) + +const frag = ` +precision mediump float; +const int MAX_LIGHTS = 3; + +varying vec2 vTexCoord; +varying vec3 vNormal; +varying vec3 vPosition; + +uniform sampler2D img; +uniform int numLights; +uniform vec3 lightPositions[MAX_LIGHTS]; +uniform vec3 lightColors[MAX_LIGHTS]; +uniform float lightStrengths[MAX_LIGHTS]; +uniform vec3 ambientLight; +uniform float materialShininess; + +void main(void) { + vec3 materialColor = texture2D(img, vTexCoord).rgb; + vec3 normal = normalize(vNormal); + gl_FragColor = vec4(abs(normal), 1.); return; + //gl_FragColor = length(vNormal) * vec4(1.); return; + vec3 color = vec3(0.0, 0.0, 0.0); + for (int i = 0; i < MAX_LIGHTS; i++) { + if (i >= numLights) break; + vec3 lightPosition = lightPositions[i]; + float distanceSquared = 0.0; /*0.00015*dot( + lightPosition - vPosition, + lightPosition - vPosition);*/ + vec3 lightDir = normalize(lightPosition - vPosition); + float lambertian = max(dot(lightDir, normal), 0.0); + color += lambertian * materialColor * lightColors[i] * + lightStrengths[i] / (1.0 + distanceSquared); + vec3 viewDir = normalize(-vPosition); + float spec = pow( + max(dot(viewDir, reflect(-lightDir, normal)), 0.0), + materialShininess); + color += spec * lightStrengths[i] * lightColors[i] / + (1.0 + distanceSquared); + } + color += ambientLight * materialColor; + gl_FragColor = vec4(color, 1.0); +} +` + +let distortShader +let texture +function setup() { + createCanvas(800, 600, WEBGL) + distortShader = createShader(vert, frag) + texture = createGraphics(500, 500) +} + +const lights = [{ + position: [200, 50, -100], + color: [1, 1, 1], + strength: 0.5, + }, + { + position: [-200, -50, -100], + color: [1, 1, 1], + strength: 0.5, + }, +]; + +function draw() { + texture.background(255, 0, 0) + texture.fill(255) + texture.noStroke() + texture.textSize(70) + texture.textAlign(CENTER, CENTER) + texture.text('hello, world', texture.width / 2, texture.height / 2) + + background(0) + + const shininess = 1000 + const ambient = [0.2, 0.2, 0.2] + + orbitControl() + noStroke() + shader(distortShader) + distortShader.setUniform('img', texture) + distortShader.setUniform('lightPositions', lights.map(l => l.position).flat()) + distortShader.setUniform('lightColors', lights.map(l => l.color).flat()) + distortShader.setUniform('lightStrengths', lights.map(l => l.strength).flat()) + distortShader.setUniform('numLights', lights.length) + distortShader.setUniform('ambientLight', ambient) + distortShader.setUniform('materialShininess', shininess) + distortShader.setUniform('time', millis()) + push() + sphere(200, 60, 30) + pop() +}