Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FourierNeuralOperators segfault #185

Open
avik-pal opened this issue Dec 13, 2024 · 1 comment
Open

FourierNeuralOperators segfault #185

avik-pal opened this issue Dec 13, 2024 · 1 comment

Comments

@avik-pal
Copy link
Collaborator

Unoptimized MLIR

module {
  // Scalar addition kernel: takes two rank-0 f32 tensors and returns
  // (%arg0 + %arg1, %arg0, %arg1). Every value is routed through a
  // stablehlo.transpose with an empty permutation (dims = []), which on a
  // rank-0 tensor has no axes to reorder.
  // Referenced by the enzyme.batch op with batch_shape [64, 160] later in this dump.
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel with the same body as @"+_broadcast_scalar":
  // returns (%arg0 + %arg1, %arg0, %arg1) on rank-0 f32 tensors.
  // Referenced by the enzyme.batch op with batch_shape [64, 160] later in this dump.
  func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar activation kernel on a rank-0 f32 tensor x. Computes
  //   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
  // and returns (result, x).
  // NOTE(review): 1.59576917 is approximately 2*sqrt(2/pi), so this looks like
  // the tanh approximation of GELU rewritten in sigmoid form -- confirm against
  // the frontend that emitted it.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @gelu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // 0.044715 * x^2 + 1
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // 1.59576917 * x * (0.044715 * x^2 + 1)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    // Results: the activation value and x, each transposed once more with dims = [].
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 160] later in this dump.
  func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar activation kernel on a rank-0 f32 tensor x. Computes
  //   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
  // and returns (result, x). Same body as @gelu_broadcast_scalar.
  // NOTE(review): the constants suggest the sigmoid form of the tanh GELU
  // approximation -- confirm against the emitting frontend.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @gelu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // 0.044715 * x^2 + 1
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // 1.59576917 * x * (0.044715 * x^2 + 1)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    // Results: the activation value and x, each transposed once more with dims = [].
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 160] later in this dump.
  func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar activation kernel on a rank-0 f32 tensor x. Computes
  //   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
  // and returns (result, x). Same body as @gelu_broadcast_scalar.
  // NOTE(review): the constants suggest the sigmoid form of the tanh GELU
  // approximation -- confirm against the emitting frontend.
  // Referenced by the enzyme.batch op with batch_shape [64, 32, 5] later in this dump.
  func.func private @gelu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // 0.044715 * x^2 + 1
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // 1.59576917 * x * (0.044715 * x^2 + 1)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    // Results: the activation value and x, each transposed once more with dims = [].
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // Referenced by the enzyme.batch op with batch_shape [64, 160] later in this dump.
  func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // NOTE(review): presumably invoked through enzyme.batch like its siblings;
  // the call site is not in the visible part of the dump.
  func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar activation kernel on a rank-0 f32 tensor x. Computes
  //   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
  // and returns (result, x). Same body as @gelu_broadcast_scalar.
  // NOTE(review): the constants suggest the sigmoid form of the tanh GELU
  // approximation -- confirm against the emitting frontend. The call site is
  // not in the visible part of the dump.
  func.func private @gelu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // 0.044715 * x^2 + 1
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // 1.59576917 * x * (0.044715 * x^2 + 1)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    // Results: the activation value and x, each transposed once more with dims = [].
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // NOTE(review): presumably invoked through enzyme.batch like its siblings;
  // the call site is not in the visible part of the dump.
  func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar activation kernel on a rank-0 f32 tensor x. Computes
  //   x * logistic(1.59576917 * x * (1 + 0.044715 * x^2))
  // and returns (result, x). Same body as @gelu_broadcast_scalar.
  // NOTE(review): the constants suggest the sigmoid form of the tanh GELU
  // approximation -- confirm against the emitting frontend. The call site is
  // not in the visible part of the dump.
  func.func private @gelu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // 0.044715 * x^2 + 1
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // 1.59576917 * x * (0.044715 * x^2 + 1)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    // Results: the activation value and x, each transposed once more with dims = [].
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar addition kernel: returns (%arg0 + %arg1, %arg0, %arg1) on rank-0
  // f32 tensors, with every value passed through an empty-permutation transpose.
  // NOTE(review): presumably invoked through enzyme.batch like its siblings;
  // the call site is not in the visible part of the dump.
  func.func private @"+_broadcast_scalar10"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Empty-permutation transposes of the two operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // The actual sum.
    %2 = stablehlo.add %0, %1 : tensor<f32>
    // Results: the sum followed by both operands, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar squared-magnitude kernel on a rank-0 f32 tensor x: computes
  // |x| * |x| and returns (result, x).
  // NOTE(review): presumably invoked through enzyme.batch like the other
  // *_broadcast_scalar helpers; the call site is not in the visible part of the dump.
  func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    // x, via an empty-permutation transpose of the argument.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    // |x|, then |x| * |x|.
    %1 = stablehlo.abs %0 : tensor<f32>
    %2 = stablehlo.multiply %1, %1 : tensor<f32>
    // Results: the squared magnitude and x, each transposed once more with dims = [].
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %19 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    %20 = stablehlo.reshape %19 : (tensor<5x32x2xf32>) -> tensor<160x2xf32>
    %21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<160x2xf32>) -> tensor<2x160xf32>
    %22 = stablehlo.dot_general %0, %21, contracting_dims = [1] x [0] : (tensor<64x2xf32>, tensor<2x160xf32>) -> tensor<64x160xf32>
    %23 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %24:3 = enzyme.batch @"+_broadcast_scalar"(%22, %23) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %25 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %26 = stablehlo.reshape %25 : (tensor<160x64xf32>) -> tensor<160x64xf32>
    %27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %28 = stablehlo.dot_general %2, %27, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %29 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %30:3 = enzyme.batch @"+_broadcast_scalar1"(%28, %29) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %31 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %32 = stablehlo.reshape %31 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %34 = stablehlo.transpose %33, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %35 = stablehlo.convert %34 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %36 = stablehlo.transpose %35, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %37 = stablehlo.fft %36, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %38 = stablehlo.transpose %37, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_1 = stablehlo.constant dense<0> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    %39 = stablehlo.dynamic_slice %38, %c, %c_1, %c_2, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %41 = stablehlo.reshape %40 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %42 = stablehlo.transpose %41, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %43 = stablehlo.transpose %42, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %44 = stablehlo.transpose %4, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %45 = stablehlo.transpose %43, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %46 = stablehlo.dot_general %44, %45, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %47 = stablehlo.transpose %46, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %48 = stablehlo.transpose %47, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_4 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %49 = stablehlo.transpose %48, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %50 = stablehlo.reshape %49 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %51 = stablehlo.transpose %50, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %52 = stablehlo.pad %51, %cst_4, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %54 = stablehlo.fft %53, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %56 = stablehlo.real %55 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %57 = stablehlo.transpose %56, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %58 = stablehlo.transpose %30#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %59 = stablehlo.reshape %58 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %60 = stablehlo.transpose %59, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %61:3 = enzyme.batch @"+_broadcast_scalar2"(%60, %57) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %62:2 = enzyme.batch @gelu_broadcast_scalar(%61#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %63 = stablehlo.transpose %62#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %64 = stablehlo.reshape %63 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %66 = stablehlo.dot_general %5, %65, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %67 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %68:3 = enzyme.batch @"+_broadcast_scalar3"(%66, %67) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %69 = stablehlo.transpose %62#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %70 = stablehlo.convert %69 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %71 = stablehlo.transpose %70, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %72 = stablehlo.fft %71, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %73 = stablehlo.transpose %72, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_6 = stablehlo.constant dense<0> : tensor<i64>
    %c_7 = stablehlo.constant dense<0> : tensor<i64>
    %c_8 = stablehlo.constant dense<0> : tensor<i64>
    %74 = stablehlo.dynamic_slice %73, %c_6, %c_7, %c_8, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %75 = stablehlo.transpose %74, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %76 = stablehlo.reshape %75 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %77 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %78 = stablehlo.transpose %77, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %79 = stablehlo.transpose %7, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %80 = stablehlo.transpose %78, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %81 = stablehlo.dot_general %79, %80, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %82 = stablehlo.transpose %81, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %83 = stablehlo.transpose %82, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_10 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %84 = stablehlo.transpose %83, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %85 = stablehlo.reshape %84 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %86 = stablehlo.transpose %85, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %87 = stablehlo.pad %86, %cst_10, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %88 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %89 = stablehlo.fft %88, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %90 = stablehlo.transpose %89, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %91 = stablehlo.real %90 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %92 = stablehlo.transpose %91, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %93 = stablehlo.transpose %68#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %94 = stablehlo.reshape %93 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %95 = stablehlo.transpose %94, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %96:3 = enzyme.batch @"+_broadcast_scalar4"(%95, %92) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %97:2 = enzyme.batch @gelu_broadcast_scalar1(%96#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %98 = stablehlo.transpose %97#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %99 = stablehlo.reshape %98 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %100 = stablehlo.transpose %99, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %101 = stablehlo.dot_general %8, %100, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %102 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %103:3 = enzyme.batch @"+_broadcast_scalar5"(%101, %102) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %104 = stablehlo.transpose %97#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %105 = stablehlo.convert %104 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %106 = stablehlo.transpose %105, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %107 = stablehlo.fft %106, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %108 = stablehlo.transpose %107, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_12 = stablehlo.constant dense<0> : tensor<i64>
    %c_13 = stablehlo.constant dense<0> : tensor<i64>
    %c_14 = stablehlo.constant dense<0> : tensor<i64>
    %109 = stablehlo.dynamic_slice %108, %c_12, %c_13, %c_14, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %110 = stablehlo.transpose %109, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %111 = stablehlo.reshape %110 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %112 = stablehlo.transpose %111, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %113 = stablehlo.transpose %112, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_15 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %114 = stablehlo.transpose %10, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %115 = stablehlo.transpose %113, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %116 = stablehlo.dot_general %114, %115, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %117 = stablehlo.transpose %116, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %118 = stablehlo.transpose %117, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_16 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %119 = stablehlo.transpose %118, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %120 = stablehlo.reshape %119 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %121 = stablehlo.transpose %120, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %122 = stablehlo.pad %121, %cst_16, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %123 = stablehlo.transpose %122, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %124 = stablehlo.fft %123, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %125 = stablehlo.transpose %124, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %126 = stablehlo.real %125 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %127 = stablehlo.transpose %126, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %128 = stablehlo.transpose %103#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %129 = stablehlo.reshape %128 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %130 = stablehlo.transpose %129, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %131:3 = enzyme.batch @"+_broadcast_scalar6"(%130, %127) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %132:2 = enzyme.batch @gelu_broadcast_scalar2(%131#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %133 = stablehlo.transpose %132#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %134 = stablehlo.reshape %133 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %135 = stablehlo.transpose %134, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %136 = stablehlo.dot_general %11, %135, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %137 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %138:3 = enzyme.batch @"+_broadcast_scalar7"(%136, %137) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %139 = stablehlo.transpose %132#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %140 = stablehlo.convert %139 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %141 = stablehlo.transpose %140, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %142 = stablehlo.fft %141, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %143 = stablehlo.transpose %142, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_18 = stablehlo.constant dense<0> : tensor<i64>
    %c_19 = stablehlo.constant dense<0> : tensor<i64>
    %c_20 = stablehlo.constant dense<0> : tensor<i64>
    %144 = stablehlo.dynamic_slice %143, %c_18, %c_19, %c_20, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %145 = stablehlo.transpose %144, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %146 = stablehlo.reshape %145 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %147 = stablehlo.transpose %146, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %148 = stablehlo.transpose %147, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_21 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %149 = stablehlo.transpose %13, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %150 = stablehlo.transpose %148, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %151 = stablehlo.dot_general %149, %150, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %152 = stablehlo.transpose %151, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %153 = stablehlo.transpose %152, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_22 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %154 = stablehlo.transpose %153, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %155 = stablehlo.reshape %154 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %156 = stablehlo.transpose %155, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %157 = stablehlo.pad %156, %cst_22, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %158 = stablehlo.transpose %157, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %159 = stablehlo.fft %158, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %160 = stablehlo.transpose %159, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %161 = stablehlo.real %160 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %162 = stablehlo.transpose %161, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %163 = stablehlo.transpose %138#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %164 = stablehlo.reshape %163 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %165 = stablehlo.transpose %164, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %166:3 = enzyme.batch @"+_broadcast_scalar8"(%165, %162) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %167:2 = enzyme.batch @gelu_broadcast_scalar3(%166#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<128x160xf32>
    %168 = stablehlo.transpose %167#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %169 = stablehlo.reshape %168 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %170 = stablehlo.transpose %169, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %171 = stablehlo.dot_general %14, %170, contracting_dims = [1] x [0] : (tensor<128x64xf32>, tensor<64x160xf32>) -> tensor<128x160xf32>
    %172 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %173 = stablehlo.reshape %172 : (tensor<128xf32>) -> tensor<1x128xf32>
    %174 = stablehlo.transpose %173, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %175 = stablehlo.broadcast_in_dim %174, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x160xf32>
    %176:3 = enzyme.batch @"+_broadcast_scalar9"(%171, %175) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>, tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>, tensor<128x160xf32>)
    %177:2 = enzyme.batch @gelu_broadcast_scalar4(%176#0) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>)
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<1x160xf32>
    %178 = stablehlo.transpose %177#0, dims = [1, 0] : (tensor<128x160xf32>) -> tensor<160x128xf32>
    %179 = stablehlo.reshape %178 : (tensor<160x128xf32>) -> tensor<160x128xf32>
    %180 = stablehlo.transpose %179, dims = [1, 0] : (tensor<160x128xf32>) -> tensor<128x160xf32>
    %181 = stablehlo.dot_general %16, %180, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x160xf32>) -> tensor<1x160xf32>
    %182 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<1xf32>) -> tensor<1x160xf32>
    %183:3 = enzyme.batch @"+_broadcast_scalar10"(%181, %182) {batch_shape = array<i64: 1, 160>} : (tensor<1x160xf32>, tensor<1x160xf32>) -> (tensor<1x160xf32>, tensor<1x160xf32>, tensor<1x160xf32>)
    %184 = stablehlo.transpose %183#0, dims = [1, 0] : (tensor<1x160xf32>) -> tensor<160x1xf32>
    %185 = stablehlo.reshape %184 : (tensor<160x1xf32>) -> tensor<5x32x1xf32>
    %186 = stablehlo.transpose %185, dims = [2, 1, 0] : (tensor<5x32x1xf32>) -> tensor<1x32x5xf32>
    %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %187:2 = enzyme.batch @abs2_broadcast_scalar(%186) {batch_shape = array<i64: 1, 32, 5>} : (tensor<1x32x5xf32>) -> (tensor<1x32x5xf32>, tensor<1x32x5xf32>)
    %188 = stablehlo.reduce(%187#0 init: %cst_25) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x32x5xf32>, tensor<f32>) -> tensor<f32>
    %189 = stablehlo.transpose %188, dims = [] : (tensor<f32>) -> tensor<f32>
    %190 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %191 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %192 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %193 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %194 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %195 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %196 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %197 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %198 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %199 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %200 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %201 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %202 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %203 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %204 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %205 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %206 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %207 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %208 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    return %189, %190, %191, %192, %193, %194, %195, %196, %197, %198, %199, %200, %201, %202, %203, %204, %205, %206, %207, %208 : tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
  }
  // Public entry point of the module. It wraps a single enzyme.autodiff call on
  // @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"; everything else in the
  // body is stablehlo.transpose / stablehlo.constant plumbing around that call.
  //
  // Inputs : 19 tensors, %arg0..%arg18.
  // Outputs: 37 tensors. The first 18 are derived from autodiff results #19..#36, the
  //          remaining 19 from autodiff results #0..#18 (see the return at the bottom).
  // NOTE(review): presumably results #19..#36 are the gradients w.r.t. %arg0..%arg17 and
  // results #0..#18 are the (possibly updated) primal values -- confirm against the
  // enzyme.autodiff op definition.
  func.func @main(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
    // Step 1: transpose every argument with its dimensions reversed (%0..%18).
    // NOTE(review): this looks like a row-major <-> column-major layout conversion
    // emitted by the frontend -- confirm.
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    // Step 2: constants. %cst..%cst_16 are all-zero tensors whose shapes match %0..%17
    // one-to-one; %cst_17 is the scalar 1.0.
    // NOTE(review): presumably the zeros are the initial gradient accumulators and the
    // 1.0 is the seed for the scalar output of the differentiated function -- confirm.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x2xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_12 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x64xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<1x128xf32>
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
    %cst_17 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // Step 3: transpose %0..%18 again with the same reversed permutation, which brings
    // each value back to the shape of the corresponding %arg (%19..%37). These are the
    // first 19 operands of the autodiff call.
    %19 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %20 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %21 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %22 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %23 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %24 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %25 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %26 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %27 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %28 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %29 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %30 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %31 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %32 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %33 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %34 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %35 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %36 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %37 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Step 4: transpose the constants the same way (%38..%56). %38 is the scalar 1.0;
    // %39..%56 are the zero tensors, now shaped like %arg0..%arg17. These are the
    // remaining 19 operands of the autodiff call.
    %38 = stablehlo.transpose %cst_17, dims = [] : (tensor<f32>) -> tensor<f32>
    %39 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %40 = stablehlo.transpose %cst_0, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %41 = stablehlo.transpose %cst_1, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %42 = stablehlo.transpose %cst_2, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %43 = stablehlo.transpose %cst_3, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %44 = stablehlo.transpose %cst_4, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %45 = stablehlo.transpose %cst_5, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %46 = stablehlo.transpose %cst_6, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %47 = stablehlo.transpose %cst_7, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %48 = stablehlo.transpose %cst_8, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %49 = stablehlo.transpose %cst_9, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %50 = stablehlo.transpose %cst_10, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %51 = stablehlo.transpose %cst_11, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %52 = stablehlo.transpose %cst_12, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %53 = stablehlo.transpose %cst_13, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %54 = stablehlo.transpose %cst_14, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %55 = stablehlo.transpose %cst_15, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %56 = stablehlo.transpose %cst_16, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 5: the autodiff call. It takes 38 operands (%19..%56) and yields 37 results.
    //   activity     : every entry is enzyme_active except the last, which is enzyme_const.
    //   ret_activity : the first entry is enzyme_activenoneed, the last is enzyme_const,
    //                  and all entries in between are enzyme_active.
    // The operand type list is split across two physical lines in this paste; it is
    // still a single operation.
    // NOTE(review): the callee operates on tensor<...xcomplex<f32>> values, and the debug
    // build asserts in mlir::Type::getIntOrFloatBitWidth ("only integers and floats have
    // a bitwidth") -- the complex element type is a plausible trigger, but the
    // unsymbolized backtrace does not establish this.
    %57:37 = enzyme.autodiff @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, 
tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>)
    // Step 6a: transpose autodiff results #0..#18 with reversed dims (%58..%76).
    %58 = stablehlo.transpose %57#0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %59 = stablehlo.transpose %57#1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %60 = stablehlo.transpose %57#2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %61 = stablehlo.transpose %57#3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %62 = stablehlo.transpose %57#4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %63 = stablehlo.transpose %57#5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %64 = stablehlo.transpose %57#6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %65 = stablehlo.transpose %57#7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %66 = stablehlo.transpose %57#8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %67 = stablehlo.transpose %57#9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %68 = stablehlo.transpose %57#10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %69 = stablehlo.transpose %57#11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %70 = stablehlo.transpose %57#12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %71 = stablehlo.transpose %57#13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %72 = stablehlo.transpose %57#14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %73 = stablehlo.transpose %57#15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %74 = stablehlo.transpose %57#16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %75 = stablehlo.transpose %57#17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %76 = stablehlo.transpose %57#18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    // Step 6b: transpose autodiff results #19..#36 with reversed dims (%77..%94).
    %77 = stablehlo.transpose %57#19, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %78 = stablehlo.transpose %57#20, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %79 = stablehlo.transpose %57#21, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %80 = stablehlo.transpose %57#22, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %81 = stablehlo.transpose %57#23, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %82 = stablehlo.transpose %57#24, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %83 = stablehlo.transpose %57#25, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %84 = stablehlo.transpose %57#26, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %85 = stablehlo.transpose %57#27, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %86 = stablehlo.transpose %57#28, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %87 = stablehlo.transpose %57#29, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %88 = stablehlo.transpose %57#30, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %89 = stablehlo.transpose %57#31, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %90 = stablehlo.transpose %57#32, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %91 = stablehlo.transpose %57#33, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %92 = stablehlo.transpose %57#34, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %93 = stablehlo.transpose %57#35, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %94 = stablehlo.transpose %57#36, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 7a: transpose %77..%94 back to the shapes declared in the function's result
    // list (%95..%112). These become the first 18 returned values.
    %95 = stablehlo.transpose %77, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %96 = stablehlo.transpose %78, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %97 = stablehlo.transpose %79, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %98 = stablehlo.transpose %80, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %99 = stablehlo.transpose %81, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %100 = stablehlo.transpose %82, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %101 = stablehlo.transpose %83, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %102 = stablehlo.transpose %84, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %103 = stablehlo.transpose %85, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %104 = stablehlo.transpose %86, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %105 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %106 = stablehlo.transpose %88, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %107 = stablehlo.transpose %89, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %108 = stablehlo.transpose %90, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %109 = stablehlo.transpose %91, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %110 = stablehlo.transpose %92, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %111 = stablehlo.transpose %93, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %112 = stablehlo.transpose %94, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 7b: transpose %58..%76 back likewise (%113..%131). These become the last 19
    // returned values.
    %113 = stablehlo.transpose %58, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %114 = stablehlo.transpose %59, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %115 = stablehlo.transpose %60, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %116 = stablehlo.transpose %61, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %117 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %118 = stablehlo.transpose %63, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %119 = stablehlo.transpose %64, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %120 = stablehlo.transpose %65, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %121 = stablehlo.transpose %66, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %122 = stablehlo.transpose %67, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %123 = stablehlo.transpose %68, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %124 = stablehlo.transpose %69, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %125 = stablehlo.transpose %70, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %126 = stablehlo.transpose %71, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %127 = stablehlo.transpose %72, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %128 = stablehlo.transpose %73, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %129 = stablehlo.transpose %74, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %130 = stablehlo.transpose %75, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %131 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Return order: %95..%112 (from autodiff results #19..#36) first, then %113..%131
    // (from autodiff results #0..#18).
    return %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127, %128, %129, %130, %131 : tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
  }
}

Error message from a debug build of enzymexlamlir-opt:

enzymexlamlir-opt: external/llvm-project/mlir/lib/IR/Types.cpp:134: unsigned int mlir::Type::getIntOrFloatBitWidth() const: Assertion `isIntOrFloat() && "only integers and floats have a bitwidth"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme envs/fno.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  enzymexlamlir-opt 0x000064943eed9fce
1  enzymexlamlir-opt 0x000064943eeda383
2  enzymexlamlir-opt 0x000064943eed7a84
3  enzymexlamlir-opt 0x000064943eed9af3
4  libc.so.6         0x00007709d944c1d0
5  libc.so.6         0x00007709d94a53f4
6  libc.so.6         0x00007709d944c120 gsignal + 32
7  libc.so.6         0x00007709d94334c3 abort + 223
8  libc.so.6         0x00007709d94333df
9  libc.so.6         0x00007709d9444177
10 enzymexlamlir-opt 0x000064943eba05c0
11 enzymexlamlir-opt 0x000064943b65a7c6
12 enzymexlamlir-opt 0x000064943b5a3f09
13 enzymexlamlir-opt 0x000064943b6c2f28
14 enzymexlamlir-opt 0x000064943e5ccfeb
15 enzymexlamlir-opt 0x000064943e5cda42
16 enzymexlamlir-opt 0x000064943af1b4d8
17 enzymexlamlir-opt 0x000064943e5d0d27
18 enzymexlamlir-opt 0x000064943e5cd769
19 enzymexlamlir-opt 0x000064943e5be495
20 enzymexlamlir-opt 0x000064943e5bf631
21 enzymexlamlir-opt 0x000064943e5c094b
22 enzymexlamlir-opt 0x000064943af1b4d8
23 enzymexlamlir-opt 0x000064943e5c06cf
24 enzymexlamlir-opt 0x000064943e5bf8fc
25 enzymexlamlir-opt 0x000064943e5bfab6
26 enzymexlamlir-opt 0x000064943b3e5548
27 enzymexlamlir-opt 0x000064943b5b2f6a
28 enzymexlamlir-opt 0x000064943ea6f1f9
29 enzymexlamlir-opt 0x000064943ea72eda
30 enzymexlamlir-opt 0x000064943af1b4d8
31 enzymexlamlir-opt 0x000064943ea7895f
32 enzymexlamlir-opt 0x000064943ea6f617
33 enzymexlamlir-opt 0x000064943ea6f959
34 enzymexlamlir-opt 0x000064943ea7180a
35 enzymexlamlir-opt 0x000064943ea715f9
36 enzymexlamlir-opt 0x000064943aefe9fb
37 enzymexlamlir-opt 0x000064943aeff1dd
38 enzymexlamlir-opt 0x000064943aeff8d0
39 enzymexlamlir-opt 0x000064943af00ae7
40 enzymexlamlir-opt 0x000064943edc5a42
41 enzymexlamlir-opt 0x000064943edc5198
42 enzymexlamlir-opt 0x000064943aeffa5e
43 enzymexlamlir-opt 0x000064943aeffd6c
44 enzymexlamlir-opt 0x000064943aeffffd
45 enzymexlamlir-opt 0x000064943aeb2642
46 libc.so.6         0x00007709d9434e08
47 libc.so.6         0x00007709d9434ecc __libc_start_main + 140
48 enzymexlamlir-opt 0x000064943aeb20b5
[1]    100225 IOT instruction (core dumped)  ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme 
@avik-pal
Copy link
Collaborator Author

xref: SciML/NeuralOperators.jl#52

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant