From 4d1d8c77cfb78592b14b663026b7eb2d3637c1db Mon Sep 17 00:00:00 2001
From: Vindaar
Date: Wed, 26 Apr 2023 10:39:53 +0200
Subject: [PATCH] add Adam and AdamW optimizers

---
 flambeau/raw/bindings/optimizers.nim | 64 ++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/flambeau/raw/bindings/optimizers.nim b/flambeau/raw/bindings/optimizers.nim
index 3c049c7..ea2e767 100644
--- a/flambeau/raw/bindings/optimizers.nim
+++ b/flambeau/raw/bindings/optimizers.nim
@@ -47,6 +47,15 @@ type
     {.pure, bycopy, importcpp: "torch::optim::SGD".}
     = object of Optimizer
 
+  Adam*
+    {.pure, bycopy, importcpp: "torch::optim::Adam".}
+    = object of Optimizer
+
+  AdamW*
+    {.pure, bycopy, importcpp: "torch::optim::AdamW".}
+    = object of Optimizer
+
+
 func step*(optim: var Optimizer){.importcpp: "#.step()".}
 
 func zero_grad*(optim: var Optimizer){.importcpp: "#.zero_grad()".}
@@ -57,6 +66,21 @@ func init*(
     ): Optim
     {.constructor, importcpp: "torch::optim::SGD(@)".}
 
+func init*(
+       Optim: type Adam,
+       params: CppVector[RawTensor],
+       learning_rate: float64
+     ): Optim
+     {.constructor, importcpp: "torch::optim::Adam(@)".}
+
+func init*(
+       Optim: type AdamW,
+       params: CppVector[RawTensor],
+       learning_rate: float64
+     ): Optim
+     {.constructor, importcpp: "torch::optim::AdamW(@)".}
+
+
 # SGD-specific
 # -----------------------------------------------------------
 type
@@ -76,3 +100,43 @@ func init*(
        options: SGDOptions
      ): T
      {.constructor, importcpp: "torch::optim::SGD(@)".}
+
+# Adam-specific
+# -----------------------------------------------------------
+type
+  AdamOptions*
+    {.pure, bycopy, importcpp: "torch::optim::AdamOptions".}
+    = object of OptimizerOptions
+
+func init*(T: type AdamOptions, learning_rate: float64): T {.constructor, importcpp: "torch::optim::AdamOptions(@)".}
+func betas*(opt: AdamOptions, beta: float64): AdamOptions {.importcpp: "#.betas(#)".}
+func eps*(opt: AdamOptions, eps: float64): AdamOptions {.importcpp: "#.eps(#)".}
+func weight_decay*(opt: AdamOptions, weight_decay: float64): AdamOptions {.importcpp: "#.weight_decay(#)".}
+func amsgrad*(opt: AdamOptions, useAmsGrad: bool): AdamOptions {.importcpp: "#.amsgrad(#)".}
+
+func init*(
+       T: type Adam,
+       params: CppVector[RawTensor],
+       options: AdamOptions
+     ): T
+     {.constructor, noInit, importcpp: "torch::optim::Adam(@)".}
+
+# AdamW-specific
+# -----------------------------------------------------------
+type
+  AdamWOptions*
+    {.pure, bycopy, importcpp: "torch::optim::AdamWOptions".}
+    = object of OptimizerOptions
+
+func init*(T: type AdamWOptions, learning_rate: float64): T {.constructor, importcpp: "torch::optim::AdamWOptions(@)".}
+func betas*(opt: AdamWOptions, beta: float64): AdamWOptions {.importcpp: "#.betas(#)".}
+func eps*(opt: AdamWOptions, eps: float64): AdamWOptions {.importcpp: "#.eps(#)".}
+func weight_decay*(opt: AdamWOptions, weight_decay: float64): AdamWOptions {.importcpp: "#.weight_decay(#)".}
+func amsgrad*(opt: AdamWOptions, useAmsGrad: bool): AdamWOptions {.importcpp: "#.amsgrad(#)".}
+
+func init*(
+       T: type AdamW,
+       params: CppVector[RawTensor],
+       options: AdamWOptions
+     ): T
+     {.constructor, noInit, importcpp: "torch::optim::AdamW(@)".}
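
Usage sketch (not part of the commit): roughly how the new AdamW binding might be driven,
assuming the flambeau raw API (RawTensor, CppVector, the umbrella flambeau_raw module) is in
scope; `params` is a placeholder for the model's trainable tensors, which are filled elsewhere.
The option setters mirror the chainable torch::optim::AdamWOptions interface, so a configured
options object can be built in one expression and handed to `init`.

  import flambeau/flambeau_raw      # assumed umbrella import exposing the raw bindings

  # placeholder: in real code this holds the model's trainable tensors
  var params: CppVector[RawTensor]

  # build the options via the chainable setters bound above
  let opts = AdamWOptions.init(1e-3).weight_decay(1e-2).amsgrad(true)
  var optimizer = AdamW.init(params, opts)

  # one update step, driven by the generic Optimizer procs bound earlier in the file
  optimizer.zero_grad()
  # ... forward and backward pass producing gradients ...
  optimizer.step()

The plain `Adam.init(params, learning_rate)` / `AdamW.init(params, learning_rate)` constructors
added in the second hunk cover the common case where only a learning rate is needed.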