Skip to content

Commit

Permalink
Add amsgrad parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 2, 2023
1 parent 3c2a6d1 commit 1f8db20
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "0ce86f0b50b78bcad82a5b0bec7a4b60f1a28bcc",
commit = "0546f0dcbf0711928b7573012ef4d64479d18c11",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1695851883 -0400",
shallow_since = "1696260490 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "0ce86f0b50b78bcad82a5b0bec7a4b60f1a28bcc",
shallow_since = "1695851883 -0400",
commit = "0546f0dcbf0711928b7573012ef4d64479d18c11",
shallow_since = "1696260490 -0400",
)

_maybe(
Expand Down
10 changes: 8 additions & 2 deletions nnc/OptimizerAddons.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public struct AdamOptimizer: Optimizer, OptimizerAddons, OptimizerTrackSteps {
public var decay: Float
public var epsilon: Float
public var scale: Float
public var amsgrad: Bool
public var parameters = [DynamicGraph_AnyParameters]() {
willSet {
for var parameter in parameters.compactMap({ $0 as? DynamicGraph_Any }) {
Expand All @@ -90,19 +91,21 @@ public struct AdamOptimizer: Optimizer, OptimizerAddons, OptimizerTrackSteps {
params.adam.beta2 = betas.1
params.adam.decay = decay
params.adam.epsilon = epsilon
params.adam.amsgrad = amsgrad ? 1 : 0
return ccv_nnc_cmd(CCV_NNC_ADAM_FORWARD, nil, params, 0)
}

public init(
_ graph: DynamicGraph, rate: Float = 0.001, step: Int = 1, betas: (Float, Float) = (0.9, 0.999),
decay: Float = 0, epsilon: Float = 1e-8
decay: Float = 0, epsilon: Float = 1e-8, amsgrad: Bool = false
) {
self.graph = graph
self.step = step
self.rate = rate
self.betas = betas
self.decay = decay
self.epsilon = epsilon
self.amsgrad = amsgrad
scale = 1
}

Expand Down Expand Up @@ -181,6 +184,7 @@ public struct AdamWOptimizer: Optimizer, OptimizerAddons, OptimizerTrackSteps {
public var decay: Float
public var epsilon: Float
public var scale: Float
public var amsgrad: Bool
public var parameters = [DynamicGraph_AnyParameters]() {
willSet {
for var parameter in parameters.compactMap({ $0 as? DynamicGraph_Any }) {
Expand All @@ -206,19 +210,21 @@ public struct AdamWOptimizer: Optimizer, OptimizerAddons, OptimizerTrackSteps {
params.adam.beta2 = betas.1
params.adam.decay = decay
params.adam.epsilon = epsilon
params.adam.amsgrad = amsgrad ? 1 : 0
return ccv_nnc_cmd(CCV_NNC_ADAMW_FORWARD, nil, params, 0)
}

public init(
_ graph: DynamicGraph, rate: Float = 0.001, step: Int = 1, betas: (Float, Float) = (0.9, 0.999),
decay: Float = 0, epsilon: Float = 1e-8
decay: Float = 0, epsilon: Float = 1e-8, amsgrad: Bool = false
) {
self.graph = graph
self.step = step
self.rate = rate
self.betas = betas
self.decay = decay
self.epsilon = epsilon
self.amsgrad = amsgrad
scale = 1
}

Expand Down

0 comments on commit 1f8db20

Please sign in to comment.