-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[orx-tensorflow] Add tensorflow support and utilities
- Loading branch information
Showing
8 changed files
with
347 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dependencies { | ||
runtimeOnly "org.tensorflow:tensorflow-core-api:$tensorflowVersion:linux-x86_64" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dependencies { | ||
runtimeOnly "org.tensorflow:tensorflow-core-api:$tensorflowVersion:macosx-x86_64" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dependencies { | ||
runtimeOnly "org.tensorflow:tensorflow-core-api:$tensorflowVersion:windows-x86_64" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
java { | ||
sourceCompatibility = JavaVersion.VERSION_1_8 | ||
targetCompatibility = JavaVersion.VERSION_1_8 | ||
} | ||
sourceSets { | ||
demo { | ||
java { | ||
srcDirs = ["src/demo/kotlin"] | ||
compileClasspath += main.getCompileClasspath() | ||
runtimeClasspath += main.getRuntimeClasspath() | ||
} | ||
} | ||
wrapgen { | ||
java { | ||
srcDirs = ["src/wrapgen/kotlin"] | ||
compileClasspath += main.getCompileClasspath() | ||
runtimeClasspath += main.getRuntimeClasspath() | ||
} | ||
} | ||
} | ||
|
||
compileWrapgenKotlin { | ||
sourceCompatibility = JavaVersion.VERSION_1_8 | ||
targetCompatibility = JavaVersion.VERSION_1_8 | ||
|
||
kotlinOptions { | ||
jvmTarget = "1.8" | ||
apiVersion = "1.4" | ||
languageVersion = "1.4" | ||
} | ||
} | ||
|
||
|
||
dependencies { | ||
implementation "com.google.code.gson:gson:$gsonVersion" | ||
demoImplementation("org.openrndr:openrndr-core:$openrndrVersion") | ||
|
||
|
||
demoRuntimeOnly(project(":orx-tensorflow-natives-$openrndrOS")) | ||
demoRuntimeOnly("org.openrndr:openrndr-gl3:$openrndrVersion") | ||
demoRuntimeOnly("org.openrndr:openrndr-gl3-natives-$openrndrOS:$openrndrVersion") | ||
demoRuntimeOnly("org.openrndr:openrndr-extensions:$openrndrVersion") | ||
demoImplementation("org.openrndr:openrndr-ffmpeg:$openrndrVersion") | ||
demoRuntimeOnly("org.openrndr:openrndr-ffmpeg-natives-$openrndrOS:$openrndrVersion") | ||
demoImplementation(project(":orx-fx")) | ||
demoImplementation(sourceSets.getByName("main").output) | ||
compile "org.tensorflow:tensorflow-core-api:$tensorflowVersion" | ||
|
||
// -- wrapgen | ||
wrapgenImplementation 'com.github.javaparser:javaparser-core:3.15.21' | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
package org.openrndr.extra.tensorflow | ||
|
||
|
||
import org.openrndr.draw.ColorBuffer | ||
import org.openrndr.draw.ColorFormat | ||
import org.openrndr.draw.ColorType | ||
import org.openrndr.draw.colorBuffer | ||
import org.openrndr.extra.tensorflow.arrays.* | ||
import org.tensorflow.Tensor | ||
import org.tensorflow.ndarray.StdArrays | ||
import org.tensorflow.ndarray.buffer.DataBuffers | ||
import org.tensorflow.types.* | ||
import org.tensorflow.types.family.TType | ||
import java.nio.ByteBuffer | ||
import java.nio.ByteOrder | ||
|
||
fun ColorBuffer.copyTo(tensor: Tensor<TFloat32>) { | ||
val buffer = ByteBuffer.allocateDirect(effectiveWidth * effectiveHeight * format.componentCount * 4) | ||
buffer.order(ByteOrder.nativeOrder()) | ||
this.read(buffer, targetType = ColorType.FLOAT32) | ||
buffer.rewind() | ||
val dataBuffer = DataBuffers.of(buffer.asFloatBuffer()) | ||
tensor.data().write(dataBuffer) | ||
} | ||
|
||
@JvmName("copyToTUint8") | ||
fun ColorBuffer.copyTo(tensor: Tensor<TUint8>) { | ||
val buffer = ByteBuffer.allocateDirect(effectiveWidth * effectiveHeight * format.componentCount) | ||
buffer.order(ByteOrder.nativeOrder()) | ||
this.read(buffer, targetType = ColorType.UINT8) | ||
buffer.rewind() | ||
val dataBuffer = DataBuffers.of(buffer) | ||
tensor.data().write(dataBuffer) | ||
} | ||
|
||
|
||
fun Tensor<TFloat32>.copyTo(colorBuffer: ColorBuffer) { | ||
val s = shape() | ||
require(s.numDimensions() == 2 || s.numDimensions() == 3) | ||
|
||
val components = when { | ||
s.numDimensions() == 3 -> s.size(2).toInt() | ||
s.numDimensions() == 4 -> s.size(3).toInt() | ||
else -> 1 | ||
} | ||
|
||
val format = when (components) { | ||
4 -> ColorFormat.RGBa | ||
3 -> ColorFormat.RGB | ||
2 -> ColorFormat.RG | ||
1 -> ColorFormat.R | ||
else -> error("only supports 1, 2, 3, or 4 components") | ||
} | ||
val buffer = ByteBuffer.allocateDirect(this.numBytes().toInt()) | ||
buffer.order(ByteOrder.nativeOrder()) | ||
val dataBuffer = DataBuffers.of(buffer.asFloatBuffer()) | ||
data().read(dataBuffer) | ||
buffer.rewind() | ||
colorBuffer.write(buffer, sourceFormat = format, sourceType = ColorType.FLOAT32) | ||
} | ||
|
||
|
||
fun <T : TType> Tensor<T>.summary() { | ||
println("type: ${this.dataType().name()}") | ||
println("shape: [${this.shape().asArray().joinToString(", ")}]") | ||
} | ||
|
||
fun Tensor<TInt32>.toIntArray(): IntArray { | ||
val elementCount = this.numBytes() / 4 | ||
val tensorData = data() | ||
val targetArray = IntArray(elementCount.toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TInt64>.toLongArray(): LongArray { | ||
val elementCount = this.numBytes() / 8 | ||
val tensorData = data() | ||
val targetArray = LongArray(elementCount.toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TUint8>.toByteArray(): ByteArray { | ||
val elementCount = this.numBytes() / 8 | ||
val tensorData = data() | ||
val targetArray = ByteArray(elementCount.toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
|
||
fun Tensor<TFloat32>.toFloatArray(): FloatArray { | ||
val elementCount = this.numBytes() / 4 | ||
val tensorData = data() | ||
val targetArray = FloatArray(elementCount.toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TFloat32>.toFloatArray2D(): FloatArray2D { | ||
val shape = this.shape() | ||
require(shape.numDimensions() == 2) { | ||
"tensor has ${shape.numDimensions()} dimensions, need 2" | ||
} | ||
val tensorData = data() | ||
val targetArray = floatArray2D(shape.size(0).toInt(), shape.size(1).toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TFloat32>.toFloatArray3D(): FloatArray3D { | ||
val shape = this.shape() | ||
require(shape.numDimensions() == 3) { | ||
"tensor has ${shape.numDimensions()} dimensions, need 3" | ||
} | ||
val tensorData = data() | ||
val targetArray = floatArray3D(shape.size(0).toInt(), shape.size(1).toInt(), shape.size(2).toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TFloat32>.toFloatArray4D(): FloatArray4D { | ||
val shape = this.shape() | ||
require(shape.numDimensions() == 4) { | ||
"tensor has ${shape.numDimensions()} dimensions, need 4" | ||
} | ||
val tensorData = data() | ||
val targetArray = floatArray4D(shape.size(0).toInt(), shape.size(1).toInt(), shape.size(2).toInt(), shape.size(3).toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TFloat64>.toDoubleArray(): DoubleArray { | ||
val elementCount = this.numBytes() / 8 | ||
val tensorData = data() | ||
val targetArray = DoubleArray(elementCount.toInt()) | ||
StdArrays.copyFrom(tensorData, targetArray) | ||
return targetArray | ||
} | ||
|
||
fun Tensor<TFloat32>.toColorBuffer(target: ColorBuffer? = null): ColorBuffer { | ||
val s = shape() | ||
require(s.numDimensions() == 2 || s.numDimensions() == 3) | ||
|
||
val width = (if (s.numDimensions() == 3) s.size(1) else s.size(0)).toInt() | ||
val height = (if (s.numDimensions() == 3) s.size(2) else s.size(1)).toInt() | ||
val components = if (s.numDimensions() == 3) s.size(0).toInt() else 1 | ||
|
||
val format = when (components) { | ||
4 -> ColorFormat.RGBa | ||
3 -> ColorFormat.RGB | ||
2 -> ColorFormat.RG | ||
1 -> ColorFormat.R | ||
else -> error("only supports 1, 2, 3, or 4 components") | ||
} | ||
|
||
val targetColorBuffer = target?: colorBuffer(width, height, format = format, type = ColorType.FLOAT32) | ||
val floatArray = toFloatArray() | ||
val bb = ByteBuffer.allocateDirect(width * height * components * 4) | ||
bb.order(ByteOrder.nativeOrder()) | ||
val fb = bb.asFloatBuffer() | ||
fb.put(floatArray) | ||
bb.rewind() | ||
targetColorBuffer.write(bb) | ||
return targetColorBuffer | ||
} | ||
|
||
|
||
@JvmName("toColorBufferTInt8") | ||
fun Tensor<TUint8>.toColorBuffer(target: ColorBuffer? = null): ColorBuffer { | ||
val s = shape() | ||
require(s.numDimensions() == 2 || s.numDimensions() == 3) | ||
|
||
val width = (if (s.numDimensions() == 3) s.size(1) else s.size(0)).toInt() | ||
val height = (if (s.numDimensions() == 3) s.size(2) else s.size(1)).toInt() | ||
val components = if (s.numDimensions() == 3) s.size(0).toInt() else 1 | ||
|
||
val format = when (components) { | ||
4 -> ColorFormat.RGBa | ||
3 -> ColorFormat.RGB | ||
2 -> ColorFormat.RG | ||
1 -> ColorFormat.R | ||
else -> error("only supports 1, 2, 3, or 4 components") | ||
} | ||
|
||
val byteArray = toByteArray() | ||
val targetColorBuffer = target?: colorBuffer(width, height, format = format, type = ColorType.UINT8) | ||
val bb = ByteBuffer.allocateDirect(width * height * components ) | ||
bb.order(ByteOrder.nativeOrder()) | ||
bb.put(byteArray) | ||
bb.rewind() | ||
targetColorBuffer.write(bb) | ||
return targetColorBuffer | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package org.openrndr.extra.tensorflow.arrays | ||
|
||
typealias FloatArray2D = Array<FloatArray> | ||
typealias FloatArray3D = Array<Array<FloatArray>> | ||
typealias FloatArray4D = Array<Array<Array<FloatArray>>> | ||
typealias FloatArray5D = Array<Array<Array<Array<FloatArray>>>> | ||
typealias FloatArray6D = Array<Array<Array<Array<Array<FloatArray>>>>> | ||
|
||
typealias IntArray2D = Array<IntArray> | ||
typealias IntArray3D = Array<Array<IntArray>> | ||
typealias IntArray4D = Array<Array<Array<IntArray>>> | ||
typealias IntArray5D = Array<Array<Array<Array<IntArray>>>> | ||
typealias IntArray6D = Array<Array<Array<Array<Array<IntArray>>>>> | ||
|
||
typealias BooleanArray2D = Array<BooleanArray> | ||
typealias BooleanArray3D = Array<Array<BooleanArray>> | ||
typealias BooleanArray4D = Array<Array<Array<BooleanArray>>> | ||
typealias BooleanArray5D = Array<Array<Array<Array<BooleanArray>>>> | ||
typealias BooleanArray6D = Array<Array<Array<Array<Array<BooleanArray>>>>> | ||
|
||
typealias LongArray2D = Array<LongArray> | ||
typealias LongArray3D = Array<Array<LongArray>> | ||
typealias LongArray4D = Array<Array<Array<LongArray>>> | ||
typealias LongArray5D = Array<Array<Array<Array<LongArray>>>> | ||
typealias LongArray6D = Array<Array<Array<Array<Array<LongArray>>>>> | ||
|
||
typealias ByteArray2D = Array<ByteArray> | ||
typealias ByteArray3D = Array<Array<ByteArray>> | ||
typealias ByteArray4D = Array<Array<Array<ByteArray>>> | ||
typealias ByteArray5D = Array<Array<Array<Array<ByteArray>>>> | ||
typealias ByteArray6D = Array<Array<Array<Array<Array<ByteArray>>>>> | ||
|
||
typealias DoubleArray2D = Array<DoubleArray> | ||
typealias DoubleArray3D = Array<Array<DoubleArray>> | ||
typealias DoubleArray4D = Array<Array<Array<DoubleArray>>> | ||
typealias DoubleArray5D = Array<Array<Array<Array<DoubleArray>>>> | ||
typealias DoubleArray6D = Array<Array<Array<Array<Array<DoubleArray>>>>> | ||
|
||
fun floatArray2D(y: Int, x: Int): FloatArray2D = Array(y) { FloatArray(x) } | ||
fun floatArray3D(z: Int, y: Int, x: Int): FloatArray3D = Array(z) { Array(y) { FloatArray(x) } } | ||
fun floatArray4D(w: Int, z: Int, y: Int, x: Int): FloatArray4D = Array(w) { Array(z) { Array(y) { FloatArray(x) } } } | ||
|
||
fun doubleArray2D(y: Int, x: Int): DoubleArray2D = Array(y) { DoubleArray(x) } | ||
fun doubleArray3D(z: Int, y: Int, x: Int): DoubleArray3D = Array(z) { Array(y) { DoubleArray(x) } } | ||
fun doubleArray4D(w: Int, z: Int, y: Int, x: Int): DoubleArray4D = Array(w) { Array(z) { Array(y) { DoubleArray(x) } } } | ||
|
||
fun intArray2D(y: Int, x: Int): IntArray2D = Array(y) { IntArray(x) } | ||
fun intArray3D(z: Int, y: Int, x: Int): IntArray3D = Array(z) { Array(y) { IntArray(x) } } | ||
fun intArray4D(w: Int, z: Int, y: Int, x: Int): IntArray4D = Array(w) { Array(z) { Array(y) { IntArray(x) } } } | ||
|
||
fun longArray2D(y: Int, x: Int): LongArray2D = Array(y) { LongArray(x) } | ||
fun longArray3D(z: Int, y: Int, x: Int): LongArray3D = Array(z) { Array(y) { LongArray(x) } } | ||
fun longArray4D(w: Int, z: Int, y: Int, x: Int): LongArray4D = Array(w) { Array(z) { Array(y) { LongArray(x) } } } | ||
|
||
fun byteArray2D(y: Int, x: Int): ByteArray2D = Array(y) { ByteArray(x) } | ||
fun byteArray3D(z: Int, y: Int, x: Int): ByteArray3D = Array(z) { Array(y) { ByteArray(x) } } | ||
fun byteArray4D(w: Int, z: Int, y: Int, x: Int): ByteArray4D = Array(w) { Array(z) { Array(y) { ByteArray(x) } } } | ||
|
||
fun booleanArray2D(y: Int, x: Int): BooleanArray2D = Array(y) { BooleanArray(x) } | ||
fun booleanArray3D(z: Int, y: Int, x: Int): BooleanArray3D = Array(z) { Array(y) { BooleanArray(x) } } | ||
fun booleanArray4D(w: Int, z: Int, y: Int, x: Int): BooleanArray4D = Array(w) { Array(z) { Array(y) { BooleanArray(x) } } } | ||
|
||
operator fun FloatArray2D.get(y: Int, x: Int) = this[y][x] | ||
operator fun FloatArray3D.get(z: Int, y: Int, x: Int) = this[z][y][x] | ||
operator fun FloatArray4D.get(w: Int, z: Int, y: Int, x: Int) = this[w][z][y][x] | ||
|
||
operator fun DoubleArray2D.get(y: Int, x: Int) = this[y][x] | ||
operator fun DoubleArray3D.get(z: Int, y: Int, x: Int) = this[z][y][x] | ||
operator fun DoubleArray4D.get(w: Int, z: Int, y: Int, x: Int) = this[w][z][y][x] | ||
|
||
operator fun IntArray2D.get(y: Int, x: Int) = this[y][x] | ||
operator fun IntArray3D.get(z: Int, y: Int, x: Int) = this[z][y][x] | ||
operator fun IntArray4D.get(w: Int, z: Int, y: Int, x: Int) = this[w][z][y][x] | ||
|
||
operator fun LongArray2D.get(y: Int, x: Int) = this[y][x] | ||
operator fun LongArray3D.get(z: Int, y: Int, x: Int) = this[z][y][x] | ||
operator fun LongArray4D.get(w: Int, z: Int, y: Int, x: Int) = this[w][z][y][x] | ||
|
||
operator fun ByteArray2D.get(y: Int, x: Int) = this[y][x] | ||
operator fun ByteArray3D.get(z: Int, y: Int, x: Int) = this[z][y][x] | ||
operator fun ByteArray4D.get(w: Int, z: Int, y: Int, x: Int) = this[w][z][y][x] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters