Skip to content

Commit

Permalink
Update to how gradientCheckpointing works.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 21, 2024
1 parent 6a00340 commit 333d0bd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "e7c76392ec4ad529797484c6759fd2f8da974d99",
commit = "20d998dc3c7008060df6fdfa17e932c9dae93d31",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1728857981 -0400",
shallow_since = "1729535387 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "e7c76392ec4ad529797484c6759fd2f8da974d99",
shallow_since = "1728857981 -0400",
commit = "20d998dc3c7008060df6fdfa17e932c9dae93d31",
shallow_since = "1729535387 -0400",
)

_maybe(
Expand Down
17 changes: 14 additions & 3 deletions nnc/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,23 @@ public class Model {
* Whether to enable gradient checkpointing for this model. Once it is enabled, we will re-run
* the model forward pass again during backward pass. This is effective at reducing memory usage.
*/
public var gradientCheckpointing: Bool {
public var gradientCheckpointing: Bool? {
get {
ccv_cnnp_model_gradient_checkpointing(cModel) != 0
let value = ccv_cnnp_model_gradient_checkpointing(cModel)
if value == 1 {
return true
} else if value == -1 {
return false
} else {
return nil
}
}
set {
ccv_cnnp_model_set_gradient_checkpointing(cModel, newValue ? 1 : 0)
if let newValue = newValue {
ccv_cnnp_model_set_gradient_checkpointing(cModel, newValue ? 1 : -1)
} else {
ccv_cnnp_model_set_gradient_checkpointing(cModel, 0)
}
}
}

Expand Down

0 comments on commit 333d0bd

Please sign in to comment.