Commit

add Adam and AdamW optimizers
Vindaar committed Apr 26, 2023
1 parent 4de3bdd commit 4d1d8c7
Showing 1 changed file with 64 additions and 0 deletions: flambeau/raw/bindings/optimizers.nim
@@ -47,6 +47,15 @@ type
{.pure, bycopy, importcpp: "torch::optim::SGD".}
= object of Optimizer

Adam*
{.pure, bycopy, importcpp: "torch::optim::Adam".}
= object of Optimizer

AdamW*
{.pure, bycopy, importcpp: "torch::optim::AdamW".}
= object of Optimizer


func step*(optim: var Optimizer){.importcpp: "#.step()".}
func zero_grad*(optim: var Optimizer){.importcpp: "#.zero_grad()".}

@@ -57,6 +66,21 @@ func init*(
): Optim
{.constructor, importcpp: "torch::optim::SGD(@)".}

func init*(
Optim: type Adam,
params: CppVector[RawTensor],
learning_rate: float64
): Optim
{.constructor, importcpp: "torch::optim::Adam(@)".}

func init*(
Optim: type AdamW,
params: CppVector[RawTensor],
learning_rate: float64
): Optim
{.constructor, importcpp: "torch::optim::AdamW(@)".}


# SGD-specific
# -----------------------------------------------------------
type
@@ -76,3 +100,43 @@ func init*(
options: SGDOptions
): T
{.constructor, importcpp: "torch::optim::SGD(@)".}

# Adam-specific
# -----------------------------------------------------------
type
AdamOptions*
{.pure, bycopy, importcpp: "torch::optim::AdamOptions".}
= object of OptimizerOptions

func init*(T: type AdamOptions, learning_rate: float64): T {.constructor, importcpp: "torch::optim::AdamOptions(@)".}
func betas*(opt: AdamOptions, beta1, beta2: float64): AdamOptions {.importcpp: "#.betas(std::make_tuple(#, #))".}
func eps*(opt: AdamOptions, eps: float64): AdamOptions {.importcpp: "#.eps(#)".}
func weight_decay*(opt: AdamOptions, weight_decay: float64): AdamOptions {.importcpp: "#.weight_decay(#)".}
func amsgrad*(opt: AdamOptions, useAmsGrad: bool): AdamOptions {.importcpp: "#.amsgrad(#)".}

func init*(
T: type Adam,
params: CppVector[RawTensor],
options: AdamOptions
): T
{.constructor, noInit, importcpp: "torch::optim::Adam(@)".}

# AdamW-specific
# -----------------------------------------------------------
type
AdamWOptions*
{.pure, bycopy, importcpp: "torch::optim::AdamWOptions".}
= object of OptimizerOptions

func init*(T: type AdamWOptions, learning_rate: float64): T {.constructor, importcpp: "torch::optim::AdamWOptions(@)".}
func betas*(opt: AdamWOptions, beta1, beta2: float64): AdamWOptions {.importcpp: "#.betas(std::make_tuple(#, #))".}
func eps*(opt: AdamWOptions, eps: float64): AdamWOptions {.importcpp: "#.eps(#)".}
func weight_decay*(opt: AdamWOptions, weight_decay: float64): AdamWOptions {.importcpp: "#.weight_decay(#)".}
func amsgrad*(opt: AdamWOptions, useAmsGrad: bool): AdamWOptions {.importcpp: "#.amsgrad(#)".}

func init*(
T: type AdamW,
params: CppVector[RawTensor],
options: AdamWOptions
): T
{.constructor, noInit, importcpp: "torch::optim::AdamW(@)".}
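
For illustration, a minimal usage sketch of the new bindings (not part of the commit). It assumes a CppVector[RawTensor] of trainable parameters has already been collected and that a loss is computed and back-propagated elsewhere; the import paths and the names params and loss are placeholders, not confirmed APIs of this repository.

import flambeau/raw/bindings/[rawtensors, optimizers]  # module paths are indicative

# Trainable parameters gathered from a model (placeholder).
var params: CppVector[RawTensor]

# Plain constructor: Adam with a fixed learning rate.
var adam = Adam.init(params, 1e-3)

# Options-based constructor: AdamW with decoupled weight decay and AMSGrad.
var adamw = AdamW.init(params,
  AdamWOptions.init(1e-3).weight_decay(1e-2).amsgrad(true))

# Typical training step:
adam.zero_grad()
# ... forward pass, loss.backward() ...
adam.step()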
