Skip to content

Commit

Permalink
Add AnyModel.
Browse files Browse the repository at this point in the history
This helps to smooth over differences between Model and AnyModelBuilder,
such that applications can pass AnyModel around if they don't care
about execution (such as compile(inputs:) call or (inputs:) call).
  • Loading branch information
liuliu committed Nov 27, 2024
1 parent 7ca5bc5 commit aab8d3c
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 4 deletions.
36 changes: 36 additions & 0 deletions nnc/AnyModel.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

public protocol AnyModel {
/**
* Whether the existing model is for testing or training.
*/
var testing: Bool { get set }
/**
* Whether to enable memory reduction for this model. The current supported memory reduction
* technique is to redo datatype conversion during backward pass if needed.
*/
var memoryReduction: Bool { get set }
/**
* Specify the maximum number of streams we need to allocate to run this model.
*/
var maxConcurrency: StreamContext.Concurrency { get set }
/**
* Abstract representation of the stateful components from the model builder.
*/
var parameters: Model.Parameters { get }
/**
* Shortcut for weight parameter.
*/
var weight: Model.Parameters { get }
/**
* Shortcut for bias parameter.
*/
var bias: Model.Parameters { get }
/**
* Broadly speaking, you can have two types of parameters, weight and bias.
* You can get them in abstract fashion with this method.
*
* - Parameter type: Whether it is weight or bias.
* - Returns: An abstract representation of parameters.
*/
func parameters(for type: Model.ParametersType) -> Model.Parameters
}
1 change: 1 addition & 0 deletions nnc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
swift_library(
name = "nnc",
srcs = [
"AnyModel.swift",
"AutoGrad.swift",
"DataFrame.swift",
"DataFrameAddons.swift",
Expand Down
2 changes: 1 addition & 1 deletion nnc/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public protocol ModelIOConvertible {

/// A model is a base class for stateful operations on a dynamic graph. It can be
/// use to construct computations statically, thus, more efficient.
public class Model {
public class Model: AnyModel {

/**
* A IO class represent the abstract input / output for a model. It can correspond
Expand Down
7 changes: 5 additions & 2 deletions nnc/ModelBuilder.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import C_nnc

/// A type-erased model builder.
public class AnyModelBuilder {
public class AnyModelBuilder: AnyModel {

public var testing: Bool = false
public var testing: Bool {
get { model!.testing }
set { model!.testing = newValue }
}

var model: Model? = nil
var t: Any? = nil
Expand Down
64 changes: 63 additions & 1 deletion nnc/Store.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4578,6 +4578,7 @@ extension DynamicGraph {
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
@inlinable
public func read(
_ key: String, model: Model, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
Expand Down Expand Up @@ -4609,13 +4610,52 @@ extension DynamicGraph {
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
@inlinable
public func read(
_ key: String, model: AnyModelBuilder, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) {
try? read(key, model: model, strict: false, codec: codec, reader: reader)
}

/**
* Read parameters into a given model.
*
* - Parameters:
* - key: The key corresponding to a particular model.
* - model: The model to be initialized with parameters from a given key.
* - strict: When this is true, will throw error if any parameters are missing.
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
@inlinable
public func read(
_ key: String, model: AnyModel, strict: Bool, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) throws {
switch model {
case let model as Model:
try read(key, model: model, strict: strict, codec: codec, reader: reader)
case let model as AnyModelBuilder:
try read(key, model: model, strict: strict, codec: codec, reader: reader)
default:
fatalError("Unrecognized model \(model)")
}
}
/**
* Read parameters into a given model.
*
* - Parameters:
* - key: The key corresponding to a particular model.
* - model: The model to be initialized with parameters from a given key.
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
public func read(
_ key: String, model: AnyModel, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) {
try? read(key, model: model, strict: false, codec: codec, reader: reader)
}
/**
* Write a tensor to the store.
*
Expand Down Expand Up @@ -4748,6 +4788,28 @@ extension DynamicGraph {
) {
write(key, model: model.model!, codec: codec, writer: writer)
}
/**
* Write a model to the store.
*
* - Parameters:
* - key: The key corresponding to a particular model.
* - model: The model where its parameters to be persisted.
* - writer: You can customize your writer to writer parameter with a different name or skip entirely.
*/
@inlinable
public func write(
_ key: String, model: AnyModel, codec: Codec = [],
writer: ((String, NNC.AnyTensor) -> ModelWriterResult)? = nil
) {
switch model {
case let model as Model:
write(key, model: model, codec: codec, writer: writer)
case let model as AnyModelBuilder:
write(key, model: model, codec: codec, writer: writer)
default:
fatalError("Unrecognized model \(model)")
}
}

init(_ store: _Store, graph: DynamicGraph) {
self.store = store
Expand Down

0 comments on commit aab8d3c

Please sign in to comment.