Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Expose OpenCL optimization pass #1353

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
399 changes: 337 additions & 62 deletions src/analysis_and_optimization/Memory_patterns.ml

Large diffs are not rendered by default.

81 changes: 68 additions & 13 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ let gen_inline_var (name : string) (id_var : string) =

let replace_fresh_local_vars (fname : string) stmt =
let f (m : (string, string) Core_kernel.Map.Poly.t) = function
| Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} ->
| Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_type; decl_id; initialize; mem_pattern} ->
let new_name =
match Map.Poly.find m decl_id with
| Some existing -> existing
| None -> gen_inline_var fname decl_id in
( Stmt.Fixed.Pattern.Decl
{decl_adtype; decl_id= new_name; decl_type; initialize}
{decl_adtype; decl_id= new_name; decl_type; initialize; mem_pattern}
, Map.Poly.set m ~key:decl_id ~data:new_name )
| Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} ->
let new_name =
Expand Down Expand Up @@ -201,7 +202,8 @@ let handle_early_returns (fname : string) opt_var stmt =
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; initialize= true }
; initialize= true
; mem_pattern= Mem_pattern.AoS }
; meta= Location_span.empty }
; Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -294,7 +296,8 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; initialize= false } ]
; initialize= false
; mem_pattern= Mem_pattern.AoS } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
Expand Down Expand Up @@ -828,7 +831,10 @@ and unenforce_initialize (lst : Stmt.Located.t list) =
| Stmt.Fixed.Pattern.Decl ({decl_id; decl_type; _} as decl_pat) -> (
let is_soa =
match decl_type with
| Type.Sized s -> SizedType.get_mem_pattern s = Mem_pattern.SoA
| Type.Sized s -> (
match SizedType.get_mem_pattern s with
| Mem_pattern.SoA | Mem_pattern.OpenCL -> true
| _ -> false )
| _ -> false in
match List.hd sub_lst with
| Some next_stmt -> (
Expand Down Expand Up @@ -975,7 +981,8 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; initialize= true }
; initialize= true
; mem_pattern= AoS }
; meta= Location_span.empty }
:: accum ) in
let lazy_code_motion_base i stmt =
Expand Down Expand Up @@ -1223,17 +1230,20 @@ let optimize_soa (mir : Program.Typed.t) =
(l : int) (aos_variables : string Set.Poly.t) =
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
match (mir_node l).pattern with
| stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in
| stmt ->
Memory_patterns.query_demotable_stmt Mem_pattern.SoA aos_variables stmt
in
let initial_variables =
List.fold ~init:Set.Poly.empty
~f:(Memory_patterns.query_initial_demotable_stmt false)
~f:(Memory_patterns.query_initial_demotable_stmt Mem_pattern.SoA false)
mir.reverse_mode_log_prob in
let mod_exprs aos_exits mod_expr =
Mir_utils.map_rec_expr
(Memory_patterns.modify_expr_pattern aos_exits)
(Memory_patterns.modify_expr_pattern Mem_pattern.SoA aos_exits)
mod_expr in
let modify_stmt_patt stmt_pattern variable_set =
Memory_patterns.modify_stmt_pattern stmt_pattern variable_set in
Memory_patterns.modify_stmt_pattern Mem_pattern.SoA stmt_pattern
variable_set in
let transform stmt =
optimize_minimal_variables ~gen_variables:gen_aos_variables
~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables
Expand All @@ -1247,6 +1257,47 @@ let optimize_soa (mir : Program.Typed.t) =
in
{mir with reverse_mode_log_prob= transform' mir.reverse_mode_log_prob}

let optimize_opencl (mir : Program.Typed.t) =
let gen_aos_variables
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(l : int) (aos_variables : string Set.Poly.t) =
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
match (mir_node l).pattern with
| stmt ->
Memory_patterns.query_demotable_stmt Mem_pattern.OpenCL aos_variables
stmt in
let initial_variables =
List.fold ~init:Set.Poly.empty
~f:(Memory_patterns.query_initial_demotable_stmt Mem_pattern.OpenCL false)
mir.reverse_mode_log_prob in
let mod_exprs aos_exits mod_expr =
Mir_utils.map_rec_expr
(Memory_patterns.modify_expr_pattern Mem_pattern.OpenCL aos_exits)
mod_expr in
let modify_stmt_patt stmt_pattern variable_set =
Memory_patterns.modify_stmt_pattern Mem_pattern.OpenCL stmt_pattern
variable_set in
let transform stmt =
optimize_minimal_variables ~gen_variables:gen_aos_variables
~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables
stmt ~extra_variables:(fun _ -> initial_variables) in
let transform' s =
match transform {pattern= SList s; meta= Location_span.empty} with
| {pattern= SList (l : Stmt.Located.t list); _} -> l
| _ ->
Common.FatalError.fatal_error_msg
[%message "Something went wrong with program transformation packing!"]
in
let reverse_log_prob = transform' mir.reverse_mode_log_prob in
let opencl_data = Memory_patterns.extract_opencl_data reverse_log_prob in
let opencl_data_gen =
Memory_patterns.create_opencl_data opencl_data mir.prepare_data in
let reverse_log_prob' =
Memory_patterns.add_opencl_data opencl_data reverse_log_prob in
{ mir with
reverse_mode_log_prob= reverse_log_prob'
; prepare_data= opencl_data_gen }

(* Apparently you need to completely copy/paste type definitions between
ml and mli files?*)
type optimization_settings =
Expand All @@ -1264,7 +1315,8 @@ type optimization_settings =
; lazy_code_motion: bool
; optimize_ad_levels: bool
; preserve_stability: bool
; optimize_soa: bool }
; optimize_soa: bool
; optimize_opencl: bool }

let settings_const b =
{ function_inlining= b
Expand All @@ -1281,7 +1333,8 @@ let settings_const b =
; lazy_code_motion= b
; optimize_ad_levels= b
; preserve_stability= not b
; optimize_soa= b }
; optimize_soa= b
; optimize_opencl= false }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd want this to be enabled as part of all_optimizations so this should be b

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk, personally this is a very specific optimization for a piece of hardware. I'd rather the user explicitly flips it on


let all_optimizations : optimization_settings = settings_const true
let no_optimizations : optimization_settings = settings_const false
Expand All @@ -1306,7 +1359,8 @@ let level_optimizations (lvl : optimization_level) : optimization_settings =
; allow_uninitialized_decls= true
; optimize_ad_levels= false
; preserve_stability= false
; optimize_soa= true }
; optimize_soa= true
; optimize_opencl= false }
| Oexperimental -> all_optimizations

let optimization_suite ?(settings = all_optimizations) mir =
Expand Down Expand Up @@ -1349,6 +1403,7 @@ let optimization_suite ?(settings = all_optimizations) mir =
(* Book: Machine idioms and instruction combining *)
; (optimize_ad_levels, settings.optimize_ad_levels)
; (optimize_soa, settings.optimize_soa)
; (optimize_opencl, settings.optimize_opencl)
(*Remove decls immediately assigned to*)
; (allow_uninitialized_decls, settings.allow_uninitialized_decls)
(* Book: Machine idioms and instruction combining *)
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Optimize.mli
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ type optimization_settings =
; lazy_code_motion: bool
; optimize_ad_levels: bool
; preserve_stability: bool
; optimize_soa: bool }
; optimize_soa: bool
; optimize_opencl: bool }

val all_optimizations : optimization_settings
val no_optimizations : optimization_settings
Expand Down
14 changes: 11 additions & 3 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,13 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
() } in
let decl =
Stmt.
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
{ Fixed.pattern=
Decl
{ decl_adtype
; decl_id
; decl_type
; initialize= true
; mem_pattern= Mem_pattern.AoS }
; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -600,7 +606,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; initialize= true } } in
; initialize= true
; mem_pattern= AoS } } in
let assignment var =
Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -683,7 +690,8 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; initialize= true }
; initialize= true
; mem_pattern= AoS }
; meta= e.meta.loc } in
let assign =
{ Stmt.Fixed.pattern=
Expand Down
7 changes: 7 additions & 0 deletions src/middle/Index.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ let apply ~default ~merge op (ind : 'a t) =
| Between (expr_top, expr_bottom) -> merge (op expr_top) (op expr_bottom)
| MultiIndex exprs -> op exprs

let map_expr ~f = function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already derive map for the type t, so this is equivalent to Index.map

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this needs deleted

| All -> All
| Single ind_expr -> Single (f ind_expr)
| Upfrom ind_expr -> Upfrom ind_expr
| Between (expr_top, expr_bottom) -> Between (f expr_top, f expr_bottom)
| MultiIndex exprs -> MultiIndex (f exprs)

let folder (acc : string Set.Poly.t) op (ind : 'a t) : string Set.Poly.t =
match ind with
| All -> acc
Expand Down
10 changes: 7 additions & 3 deletions src/middle/Mem_pattern.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
open Core_kernel
open Core_kernel.Poly

(**
* This type represents whether or not an autodiff type can be represented
Expand All @@ -11,13 +10,18 @@ open Core_kernel.Poly
* (fyi a var in the C++ code is an alias for var_value<double>)
*
**)
type t = AoS | SoA [@@deriving sexp, compare, map, hash, fold, equal]
type t = AoS | SoA | OpenCL [@@deriving sexp, compare, map, hash, fold, equal]

let pp ppf = function
| AoS -> Fmt.string ppf "AoS"
| SoA -> Fmt.string ppf "SoA"
| OpenCL -> Fmt.string ppf "OpenCL"

let is_soa mem = match mem with SoA -> true | _ -> false
let is_aos mem = match mem with AoS -> true | _ -> false
let is_opencl mem = match mem with OpenCL -> true | _ -> false
Comment on lines +13 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a relationship between these right? All "OpenCL" memory layouts are automatically SoA, right?

Should is_soa take this into account?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All "OpenCL" memory layouts are automatically SoA, right?

Yes, but I think we should keep them different. For instance like printing the cpp we probably want to know the difference. I need to think about a right way to describe this


let lub_mem_pat lst =
let find_soa mem_pat = mem_pat = SoA in
let find_soa mem_pat = match mem_pat with SoA -> true | _ -> false in
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Equivalent to is_soa above

let any_soa = List.exists ~f:find_soa lst in
match any_soa with true -> SoA | false -> AoS
29 changes: 24 additions & 5 deletions src/middle/SizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ let rec get_mem_pattern st =
| SVector (mem, _) | SRowVector (mem, _) | SMatrix (mem, _, _) -> mem
| SArray (t, _) -> get_mem_pattern t

(*Given a sizedtype, demote it's mem pattern from SoA to AoS*)
(*Given a sizedtype, demote it's mem pattern from SoA or OpenCL to AoS*)
let rec demote_sizedtype_mem st =
match st with
| ( SInt | SReal | SComplex
Expand All @@ -201,12 +201,12 @@ let rec demote_sizedtype_mem st =
| SComplexMatrix (_, _) ) as ret ->
ret
| SArray (inner_type, dim) -> SArray (demote_sizedtype_mem inner_type, dim)
| SVector (SoA, dim) -> SVector (AoS, dim)
| SRowVector (SoA, dim) -> SRowVector (AoS, dim)
| SMatrix (SoA, dim1, dim2) -> SMatrix (AoS, dim1, dim2)
| SVector ((SoA | OpenCL), dim) -> SVector (AoS, dim)
| SRowVector ((SoA | OpenCL), dim) -> SRowVector (AoS, dim)
| SMatrix ((SoA | OpenCL), dim1, dim2) -> SMatrix (AoS, dim1, dim2)
| STuple subtypes -> STuple (List.map ~f:demote_sizedtype_mem subtypes)

(*Given a sizedtype, promote it's mem pattern from AoS to SoA*)
(*Given a sizedtype, promote it's mem pattern from AoS to SoA *)
let rec promote_sizedtype_mem st =
match st with
| SVector (AoS, dim) -> SVector (SoA, dim)
Expand All @@ -220,6 +220,15 @@ let modify_sizedtype_mem (mem_pattern : Mem_pattern.t) st =
match mem_pattern with
| AoS -> demote_sizedtype_mem st
| SoA -> promote_sizedtype_mem st
| OpenCL -> promote_sizedtype_mem st
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpenCL is "promoted" to SoA by this function?


let rec promote_mem (mem_pattern : Mem_pattern.t) st =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it should be called replace_mem rather than promote

match st with
| SVector (_, dim) -> SVector (mem_pattern, dim)
| SRowVector (_, dim) -> SRowVector (mem_pattern, dim)
| SMatrix (_, dim1, dim2) -> SMatrix (mem_pattern, dim1, dim2)
| SArray (inner_type, dim) -> SArray (promote_mem mem_pattern inner_type, dim)
| _ -> st
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to consider things inside tuples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to allow OpenCL to use tuples then yes.

Also the Decls having a mem_pattern tag won't work because of tuples either :( I think we do just need to tag all sized types as having a memory pattern

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that’s too bad. You could do something like what we had to do for Autodiff level and have a tuple specific variant in the type, but I wasn’t super happy with that either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah eod I think it's going to look a little odd in some places but I think it's fine to just add it to the sized types


let rec has_mem_pattern = function
| SInt | SReal | SComplex | SComplexVector _ | SComplexRowVector _
Expand All @@ -229,6 +238,16 @@ let rec has_mem_pattern = function
| SArray (t, _) -> has_mem_pattern t
| STuple subtypes -> List.exists ~f:has_mem_pattern subtypes

let is_eigen_type st =
match st with
| (SVector (mem, _) | SRowVector (mem, _) | SMatrix (mem, _, _))
when Mem_pattern.is_opencl mem ->
false
| SVector _ | SRowVector _ | SMatrix _ | SComplexRowVector _
|SComplexVector _ | SComplexMatrix _ ->
true
| _ -> false

Comment on lines +253 to +262
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel a little nervous that SizedType.is_eigen_type st and UnsizedType.is_eigen_type (SizedType.to_unsized st) would return different things in some circumstances. I suppose it depends on how/where both of them are used, but I'd feel a bit more confident with a is_eigen_type which matches Unsized and using is_eigen st && not (is_opencl st) where needed.

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree it's very shaky. I think splitting it out it a better idea

(** The inverse of [get_array_dims]
*)
let build_sarray dims st =
Expand Down
Loading