
Visualize stablehlo/vhlo graphs in tt-explorer #1721

Open
Tracked by #1705
tapspatel opened this issue Jan 7, 2025 · 0 comments
Requesting an additional feature in tt-explorer: the ability to visualize stablehlo graphs. When bringing up models from tt-torch or tt-xla, they pass through several sub-dialect steps, and it is much easier to debug when the decomposition of the graph across the different dialects can be visualized.

Sample vhlo graph

module @jit_module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  vhlo.func_v1 @main(%arg0: !vhlo.tensor_v1<1x128x!vhlo.f32_v1>, %arg1: !vhlo.tensor_v1<128x64x!vhlo.f32_v1>, %arg2: !vhlo.tensor_v1<1x64x!vhlo.f32_v1>, %arg3: !vhlo.tensor_v1<1x64x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<1x64x!vhlo.f32_v1>) {
    %0 = "vhlo.dot_general_v2"(%arg0, %arg1) <{accumulation_type = #vhlo.type_v1<!vhlo.none_v1>, allow_imprecise_accumulation = #vhlo.type_v1<!vhlo.none_v1>, lhs_batching_dimensions = #vhlo.tensor_v1<dense<> : tensor<0xi64>>, lhs_component_count = #vhlo.type_v1<!vhlo.none_v1>, lhs_contracting_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>, lhs_precision_type = #vhlo.type_v1<!vhlo.none_v1>, num_primitive_operations = #vhlo.type_v1<!vhlo.none_v1>, precision_config = #vhlo.array_v1<[#vhlo<precision_v1 DEFAULT>, #vhlo<precision_v1 DEFAULT>]>, rhs_batching_dimensions = #vhlo.tensor_v1<dense<> : tensor<0xi64>>, rhs_component_count = #vhlo.type_v1<!vhlo.none_v1>, rhs_contracting_dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>, rhs_precision_type = #vhlo.type_v1<!vhlo.none_v1>}> : (!vhlo.tensor_v1<1x128x!vhlo.f32_v1>, !vhlo.tensor_v1<128x64x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x64x!vhlo.f32_v1>
    %1 = "vhlo.multiply_v1"(%0, %arg2) : (!vhlo.tensor_v1<1x64x!vhlo.f32_v1>, !vhlo.tensor_v1<1x64x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x64x!vhlo.f32_v1>
    %2 = "vhlo.add_v1"(%1, %arg3) : (!vhlo.tensor_v1<1x64x!vhlo.f32_v1>, !vhlo.tensor_v1<1x64x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x64x!vhlo.f32_v1>
    "vhlo.return_v1"(%2) : (!vhlo.tensor_v1<1x64x!vhlo.f32_v1>) -> ()
  } {arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"mhlo.sharding"> = #vhlo.string_v1<"{replicated}">}>, #vhlo.dict_v1<{#vhlo.string_v1<"mhlo.sharding"> = #vhlo.string_v1<"{replicated}">}>, #vhlo.dict_v1<{#vhlo.string_v1<"mhlo.sharding"> = #vhlo.string_v1<"{replicated}">}>, #vhlo.dict_v1<{#vhlo.string_v1<"mhlo.sharding"> = #vhlo.string_v1<"{replicated}">}>]>, res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"jax.result_info"> = #vhlo.string_v1<"">}>]>, sym_visibility = #vhlo.string_v1<"public">}
}

Sample stablehlo graph

module @jit_module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x128xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<128x64xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<1x64xf32> {mhlo.sharding = "{replicated}"}, %arg3: tensor<1x64xf32> {mhlo.sharding = "{replicated}"}) -> (tensor<1x64xf32> {jax.result_info = ""}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x64xf32>) -> tensor<1x64xf32>
    %1 = stablehlo.multiply %0, %arg2 : tensor<1x64xf32>
    %2 = stablehlo.add %1, %arg3 : tensor<1x64xf32>
    return %2 : tensor<1x64xf32>
  }
}
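As a rough illustration of the structure a visualizer would render, the stablehlo sample above can be reduced to a def-use graph with a few lines of stdlib Python. This is only a sketch of the idea, not tt-explorer code; the regex-based parsing is ad hoc and only handles simple single-result ops like the ones in the sample.

```python
import re

# The op body of the stablehlo sample above.
MLIR = '''
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x64xf32>) -> tensor<1x64xf32>
%1 = stablehlo.multiply %0, %arg2 : tensor<1x64xf32>
%2 = stablehlo.add %1, %arg3 : tensor<1x64xf32>
'''

def op_graph(text):
    """Map each result SSA value to (op name, operand SSA values)."""
    graph = {}
    for line in text.strip().splitlines():
        m = re.match(r'\s*(%\w+)\s*=\s*(stablehlo\.\w+)\s+(.*)', line)
        if not m:
            continue
        result, op, rest = m.groups()
        # Operands are the %-prefixed names before the trailing type clause.
        operands = re.findall(r'%\w+', rest.split(':')[0])
        graph[result] = (op, operands)
    return graph

for result, (op, operands) in op_graph(MLIR).items():
    print(f'{result} <- {op}({", ".join(operands)})')
```

Running this prints the three-node chain (dot_general feeding multiply feeding add), which is exactly the dataflow a graph visualizer would draw as nodes and edges.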