Skip to content

Commit

Permalink
Add backpropagation implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rounakdatta committed Apr 10, 2024
1 parent fad34ca commit 2b9c099
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
50 changes: 37 additions & 13 deletions lib/neuron.ml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
module Neuron = struct
  (* A scalar node in an autodiff computation graph (micrograd-style).
     Each node stores its forward value, the gradient accumulated during
     backpropagation, and a closure that pushes its gradient to the nodes
     it was computed from. *)
  type t = {
    data : float;                    (* forward-pass value; fixed at creation *)
    mutable grad : float;            (* d(output)/d(this), accumulated by backward passes *)
    mutable backward : unit -> unit; (* propagates this node's grad to its dependents *)

    (* capturing the operator would be useful later when we add some viz *)
    operator : string;
    dependents : t list;             (* the nodes this value was computed from *)
  }

  (* [create dt op deps] builds a node holding value [dt], produced by
     operator [op] from dependent nodes [deps]. Gradient starts at 0 and
     the backward pass is a no-op until an operation installs one. *)
  let create dt op deps = {
    data = dt;
    grad = 0.;
    backward = (fun () -> ());
    operator = op;
    dependents = deps;
  }

  (* Addition: d(a+b)/da = 1 and d(a+b)/db = 1, so the result's gradient
     flows to both operands unchanged. Gradients are accumulated (+.) so a
     node used several times collects every contribution. *)
  let add base partner =
    let resultant = create (base.data +. partner.data) "+" [base; partner] in
    resultant.backward <- (fun () ->
      base.grad <- base.grad +. resultant.grad;
      (* NOTE(review): this line was hidden in a collapsed diff hunk;
         reconstructed from the chain rule for addition. *)
      partner.grad <- partner.grad +. resultant.grad
    );
    resultant

  (* Multiplication: d(a*b)/da = b and d(a*b)/db = a. *)
  let mul base partner =
    let resultant = create (base.data *. partner.data) "*" [base; partner] in
    resultant.backward <- (fun () ->
      base.grad <- base.grad +. partner.data *. resultant.grad;
      (* NOTE(review): hidden in a collapsed diff hunk; reconstructed from
         the chain rule for multiplication. *)
      partner.grad <- partner.grad +. base.data *. resultant.grad
    );
    resultant

  (* Power with a constant float exponent: d(a**n)/da = n * a**(n-1).
     The exponent is a plain float, not a graph node, so no gradient
     flows into it. *)
  let exp base exponent =
    let resultant = create (base.data ** exponent) "**" [base] in
    resultant.backward <- (fun () ->
      base.grad <- base.grad +. exponent *. (base.data ** (exponent -. 1.)) *. resultant.grad
    );
    resultant

  (* ReLU: passes the gradient through only where the input was strictly
     positive. *)
  let relu base =
    let resultant = create (max 0. base.data) "relu" [base] in
    resultant.backward <- (fun () ->
      base.grad <- base.grad +. (if base.data > 0. then resultant.grad else 0.)
    );
    resultant

  (* [backpropagate base] seeds [base.grad] with 1 and runs every reachable
     node's backward closure in reverse-topological order, so each node's
     gradient is complete before it is pushed further down the graph. *)
  let backpropagate base =
    (* we topologically sort all the connected nodes from the base node *)
    (* BUGFIX: use List.memq (physical equality). List.mem uses polymorphic
       structural equality, which raises Invalid_argument on the [backward]
       closure field as soon as two distinct nodes tie on [data] and [grad]
       (all grads are 0 here, so ties are routine). Node identity in a graph
       is pointer identity anyway. *)
    let rec sort_topologically visited resultant candidate =
      if not (List.memq candidate visited) then
        let visited = candidate :: visited in
        let visited, resultant =
          List.fold_left (fun (visited, resultant) dependent ->
            sort_topologically visited resultant dependent
          ) (visited, resultant) candidate.dependents
        in
        (visited, candidate :: resultant)
      else
        (visited, resultant)
    in
    let _, resultant = sort_topologically [] [] base in

    base.grad <- 1.0;

    (* now we go backward from end-mouth of the graph to the connected start nodes *)
    (* and propagate the gradient changes *)
    (* BUGFIX: [resultant] is already base-first (each node is consed AFTER
       its dependents, so the last node visited — the base — sits at the
       head). Reversing it ran intermediate backwards before their own grad
       was set, zeroing all gradients beyond depth 1. Iterate it as-is. *)
    List.iter (fun v -> v.backward ()) resultant
end
6 changes: 3 additions & 3 deletions lib/neuron.mli
Original file line number Diff line number Diff line change
(* Interface for the autodiff neuron: the node type stays abstract so the
   graph can only be built and differentiated through these operations. *)
module Neuron : sig
  type t

  (* Constructor; constructs a unit neuron from a value, an operator and
     the list of nodes it was computed from. *)
  val create : float -> string -> t list -> t

  (* Handles the gradient flows in addition operation. *)
  val add : t -> t -> t

  (* Handles the gradient flows in multiplication operation. *)
  (* NOTE(review): this declaration sat in a collapsed diff hunk; it is
     reconstructed to match the [mul] defined in neuron.ml — confirm
     against the full file. *)
  val mul : t -> t -> t

  (* Handles the gradient flows in exponent / power operation. *)
  (* second argument is the exponent. *)
  val exp : t -> float -> t

  (* Handles the gradient flows in ReLU operation. *)
  val relu : t -> t

  (* Handles backpropagation of the gradients for all the nodes connected to the specified base node. *)
  val backpropagate : t -> unit
end

0 comments on commit 2b9c099

Please sign in to comment.