We read every piece of feedback and take your input very seriously.
For a list of all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
module { func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @gelu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32> %cst_0 = 
stablehlo.constant dense<1.59576917> : tensor<f32> %1 = stablehlo.multiply %0, %0 : tensor<f32> %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> %2 = stablehlo.multiply %1, %cst : tensor<f32> %3 = stablehlo.add %2, %cst_1 : tensor<f32> %4 = stablehlo.multiply %cst_0, %0 : tensor<f32> %5 = stablehlo.multiply %4, %3 : tensor<f32> %6 = stablehlo.logistic %5 : tensor<f32> %7 = stablehlo.multiply %0, %6 : tensor<f32> %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32> %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %8, %9 : tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @gelu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %cst = stablehlo.constant dense<4.471500e-02> : 
tensor<f32> %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32> %1 = stablehlo.multiply %0, %0 : tensor<f32> %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> %2 = stablehlo.multiply %1, %cst : tensor<f32> %3 = stablehlo.add %2, %cst_1 : tensor<f32> %4 = stablehlo.multiply %cst_0, %0 : tensor<f32> %5 = stablehlo.multiply %4, %3 : tensor<f32> %6 = stablehlo.logistic %5 : tensor<f32> %7 = stablehlo.multiply %0, %6 : tensor<f32> %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32> %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %8, %9 : tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @gelu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %cst = stablehlo.constant 
dense<4.471500e-02> : tensor<f32> %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32> %1 = stablehlo.multiply %0, %0 : tensor<f32> %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> %2 = stablehlo.multiply %1, %cst : tensor<f32> %3 = stablehlo.add %2, %cst_1 : tensor<f32> %4 = stablehlo.multiply %cst_0, %0 : tensor<f32> %5 = stablehlo.multiply %4, %3 : tensor<f32> %6 = stablehlo.logistic %5 : tensor<f32> %7 = stablehlo.multiply %0, %6 : tensor<f32> %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32> %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %8, %9 : tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @gelu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %cst = 
stablehlo.constant dense<4.471500e-02> : tensor<f32> %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32> %1 = stablehlo.multiply %0, %0 : tensor<f32> %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> %2 = stablehlo.multiply %1, %cst : tensor<f32> %3 = stablehlo.add %2, %cst_1 : tensor<f32> %4 = stablehlo.multiply %cst_0, %0 : tensor<f32> %5 = stablehlo.multiply %4, %3 : tensor<f32> %6 = stablehlo.logistic %5 : tensor<f32> %7 = stablehlo.multiply %0, %6 : tensor<f32> %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32> %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %8, %9 : tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @gelu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32> %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32> %1 = stablehlo.multiply %0, %0 : tensor<f32> %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32> %2 = stablehlo.multiply %1, %cst : tensor<f32> %3 = stablehlo.add %2, %cst_1 : tensor<f32> %4 = stablehlo.multiply %cst_0, %0 : tensor<f32> %5 = stablehlo.multiply %4, %3 : tensor<f32> %6 = stablehlo.logistic %5 : tensor<f32> %7 = stablehlo.multiply %0, %6 : tensor<f32> %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32> %9 = 
stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %8, %9 : tensor<f32>, tensor<f32> } func.func private @"+_broadcast_scalar10"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32> %2 = stablehlo.add %0, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32> } func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) { %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32> %1 = stablehlo.abs %0 : tensor<f32> %2 = stablehlo.multiply %1, %1 : tensor<f32> %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32> %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32> return %3, %4 : tensor<f32>, tensor<f32> } func.func private @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, 
tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) { %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32> %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>> %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>> %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>> %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>> %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32> %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32> %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32> %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32> %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32> %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32> %19 = stablehlo.transpose %18, 
dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32> %20 = stablehlo.reshape %19 : (tensor<5x32x2xf32>) -> tensor<160x2xf32> %21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<160x2xf32>) -> tensor<2x160xf32> %22 = stablehlo.dot_general %0, %21, contracting_dims = [1] x [0] : (tensor<64x2xf32>, tensor<2x160xf32>) -> tensor<64x160xf32> %23 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32> %24:3 = enzyme.batch @"+_broadcast_scalar"(%22, %23) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>) %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32> %25 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %26 = stablehlo.reshape %25 : (tensor<160x64xf32>) -> tensor<160x64xf32> %27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32> %28 = stablehlo.dot_general %2, %27, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32> %29 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32> %30:3 = enzyme.batch @"+_broadcast_scalar1"(%28, %29) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>) %31 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %32 = stablehlo.reshape %31 : (tensor<160x64xf32>) -> tensor<5x32x64xf32> %33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32> %34 = stablehlo.transpose %33, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32> %35 = stablehlo.convert %34 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>> %36 = stablehlo.transpose %35, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %37 = stablehlo.fft %36, type = FFT, length = [32] : 
(tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %38 = stablehlo.transpose %37, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %c = stablehlo.constant dense<0> : tensor<i64> %c_1 = stablehlo.constant dense<0> : tensor<i64> %c_2 = stablehlo.constant dense<0> : tensor<i64> %39 = stablehlo.dynamic_slice %38, %c, %c_1, %c_2, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>> %40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %41 = stablehlo.reshape %40 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %42 = stablehlo.transpose %41, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %43 = stablehlo.transpose %42, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>> %44 = stablehlo.transpose %4, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %45 = stablehlo.transpose %43, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %46 = stablehlo.dot_general %44, %45, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %47 = stablehlo.transpose %46, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %48 = stablehlo.transpose %47, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %cst_4 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> %49 = stablehlo.transpose %48, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %50 = stablehlo.reshape %49 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %51 = stablehlo.transpose %50, dims = [2, 
1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %52 = stablehlo.pad %51, %cst_4, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>> %53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %54 = stablehlo.fft %53, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %56 = stablehlo.real %55 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32> %57 = stablehlo.transpose %56, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32> %58 = stablehlo.transpose %30#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %59 = stablehlo.reshape %58 : (tensor<160x64xf32>) -> tensor<5x32x64xf32> %60 = stablehlo.transpose %59, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32> %61:3 = enzyme.batch @"+_broadcast_scalar2"(%60, %57) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>) %62:2 = enzyme.batch @gelu_broadcast_scalar(%61#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>) %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32> %63 = stablehlo.transpose %62#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32> %64 = stablehlo.reshape %63 : (tensor<5x32x64xf32>) -> tensor<160x64xf32> %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32> %66 = stablehlo.dot_general %5, %65, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32> %67 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32> %68:3 = enzyme.batch @"+_broadcast_scalar3"(%66, %67) 
{batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>) %69 = stablehlo.transpose %62#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32> %70 = stablehlo.convert %69 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>> %71 = stablehlo.transpose %70, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %72 = stablehlo.fft %71, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %73 = stablehlo.transpose %72, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %c_6 = stablehlo.constant dense<0> : tensor<i64> %c_7 = stablehlo.constant dense<0> : tensor<i64> %c_8 = stablehlo.constant dense<0> : tensor<i64> %74 = stablehlo.dynamic_slice %73, %c_6, %c_7, %c_8, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>> %75 = stablehlo.transpose %74, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %76 = stablehlo.reshape %75 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %77 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %78 = stablehlo.transpose %77, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>> %79 = stablehlo.transpose %7, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %80 = stablehlo.transpose %78, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %81 = stablehlo.dot_general %79, %80, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %82 = stablehlo.transpose %81, dims = [1, 2, 0] : 
(tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %83 = stablehlo.transpose %82, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %cst_10 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> %84 = stablehlo.transpose %83, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %85 = stablehlo.reshape %84 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %86 = stablehlo.transpose %85, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %87 = stablehlo.pad %86, %cst_10, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>> %88 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %89 = stablehlo.fft %88, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %90 = stablehlo.transpose %89, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %91 = stablehlo.real %90 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32> %92 = stablehlo.transpose %91, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32> %93 = stablehlo.transpose %68#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %94 = stablehlo.reshape %93 : (tensor<160x64xf32>) -> tensor<5x32x64xf32> %95 = stablehlo.transpose %94, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32> %96:3 = enzyme.batch @"+_broadcast_scalar4"(%95, %92) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>) %97:2 = enzyme.batch @gelu_broadcast_scalar1(%96#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>) %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32> %98 = stablehlo.transpose 
%97#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32> %99 = stablehlo.reshape %98 : (tensor<5x32x64xf32>) -> tensor<160x64xf32> %100 = stablehlo.transpose %99, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32> %101 = stablehlo.dot_general %8, %100, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32> %102 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32> %103:3 = enzyme.batch @"+_broadcast_scalar5"(%101, %102) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>) %104 = stablehlo.transpose %97#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32> %105 = stablehlo.convert %104 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>> %106 = stablehlo.transpose %105, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %107 = stablehlo.fft %106, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %108 = stablehlo.transpose %107, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %c_12 = stablehlo.constant dense<0> : tensor<i64> %c_13 = stablehlo.constant dense<0> : tensor<i64> %c_14 = stablehlo.constant dense<0> : tensor<i64> %109 = stablehlo.dynamic_slice %108, %c_12, %c_13, %c_14, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>> %110 = stablehlo.transpose %109, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %111 = stablehlo.reshape %110 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %112 = stablehlo.transpose %111, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %113 = stablehlo.transpose %112, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %cst_15 = stablehlo.constant 
dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>> %114 = stablehlo.transpose %10, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %115 = stablehlo.transpose %113, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %116 = stablehlo.dot_general %114, %115, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %117 = stablehlo.transpose %116, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %118 = stablehlo.transpose %117, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %cst_16 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> %119 = stablehlo.transpose %118, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %120 = stablehlo.reshape %119 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %121 = stablehlo.transpose %120, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %122 = stablehlo.pad %121, %cst_16, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>> %123 = stablehlo.transpose %122, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %124 = stablehlo.fft %123, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %125 = stablehlo.transpose %124, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %126 = stablehlo.real %125 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32> %127 = stablehlo.transpose %126, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32> %128 = stablehlo.transpose %103#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %129 = stablehlo.reshape %128 : (tensor<160x64xf32>) -> 
tensor<5x32x64xf32> %130 = stablehlo.transpose %129, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32> %131:3 = enzyme.batch @"+_broadcast_scalar6"(%130, %127) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>) %132:2 = enzyme.batch @gelu_broadcast_scalar2(%131#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>) %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32> %133 = stablehlo.transpose %132#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32> %134 = stablehlo.reshape %133 : (tensor<5x32x64xf32>) -> tensor<160x64xf32> %135 = stablehlo.transpose %134, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32> %136 = stablehlo.dot_general %11, %135, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32> %137 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32> %138:3 = enzyme.batch @"+_broadcast_scalar7"(%136, %137) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>) %139 = stablehlo.transpose %132#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32> %140 = stablehlo.convert %139 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>> %141 = stablehlo.transpose %140, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %142 = stablehlo.fft %141, type = FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %143 = stablehlo.transpose %142, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %c_18 = stablehlo.constant dense<0> : tensor<i64> %c_19 = stablehlo.constant dense<0> : tensor<i64> %c_20 = stablehlo.constant dense<0> : tensor<i64> %144 = stablehlo.dynamic_slice %143, %c_18, %c_19, %c_20, sizes = 
[16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>> %145 = stablehlo.transpose %144, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %146 = stablehlo.reshape %145 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %147 = stablehlo.transpose %146, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %148 = stablehlo.transpose %147, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %cst_21 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>> %149 = stablehlo.transpose %13, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %150 = stablehlo.transpose %148, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %151 = stablehlo.dot_general %149, %150, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %152 = stablehlo.transpose %151, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>> %153 = stablehlo.transpose %152, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %cst_22 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> %154 = stablehlo.transpose %153, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %155 = stablehlo.reshape %154 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>> %156 = stablehlo.transpose %155, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>> %157 = stablehlo.pad %156, %cst_22, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>> %158 = stablehlo.transpose %157, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> 
tensor<5x64x32xcomplex<f32>> %159 = stablehlo.fft %158, type = IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>> %160 = stablehlo.transpose %159, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>> %161 = stablehlo.real %160 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32> %162 = stablehlo.transpose %161, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32> %163 = stablehlo.transpose %138#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32> %164 = stablehlo.reshape %163 : (tensor<160x64xf32>) -> tensor<5x32x64xf32> %165 = stablehlo.transpose %164, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32> %166:3 = enzyme.batch @"+_broadcast_scalar8"(%165, %162) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>) %167:2 = enzyme.batch @gelu_broadcast_scalar3(%166#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>) %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<128x160xf32> %168 = stablehlo.transpose %167#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32> %169 = stablehlo.reshape %168 : (tensor<5x32x64xf32>) -> tensor<160x64xf32> %170 = stablehlo.transpose %169, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32> %171 = stablehlo.dot_general %14, %170, contracting_dims = [1] x [0] : (tensor<128x64xf32>, tensor<64x160xf32>) -> tensor<128x160xf32> %172 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32> %173 = stablehlo.reshape %172 : (tensor<128xf32>) -> tensor<1x128xf32> %174 = stablehlo.transpose %173, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32> %175 = stablehlo.broadcast_in_dim %174, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x160xf32> %176:3 = enzyme.batch @"+_broadcast_scalar9"(%171, %175) {batch_shape = array<i64: 128, 160>} : 
(tensor<128x160xf32>, tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>, tensor<128x160xf32>) %177:2 = enzyme.batch @gelu_broadcast_scalar4(%176#0) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>) %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<1x160xf32> %178 = stablehlo.transpose %177#0, dims = [1, 0] : (tensor<128x160xf32>) -> tensor<160x128xf32> %179 = stablehlo.reshape %178 : (tensor<160x128xf32>) -> tensor<160x128xf32> %180 = stablehlo.transpose %179, dims = [1, 0] : (tensor<160x128xf32>) -> tensor<128x160xf32> %181 = stablehlo.dot_general %16, %180, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x160xf32>) -> tensor<1x160xf32> %182 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<1xf32>) -> tensor<1x160xf32> %183:3 = enzyme.batch @"+_broadcast_scalar10"(%181, %182) {batch_shape = array<i64: 1, 160>} : (tensor<1x160xf32>, tensor<1x160xf32>) -> (tensor<1x160xf32>, tensor<1x160xf32>, tensor<1x160xf32>) %184 = stablehlo.transpose %183#0, dims = [1, 0] : (tensor<1x160xf32>) -> tensor<160x1xf32> %185 = stablehlo.reshape %184 : (tensor<160x1xf32>) -> tensor<5x32x1xf32> %186 = stablehlo.transpose %185, dims = [2, 1, 0] : (tensor<5x32x1xf32>) -> tensor<1x32x5xf32> %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32> %187:2 = enzyme.batch @abs2_broadcast_scalar(%186) {batch_shape = array<i64: 1, 32, 5>} : (tensor<1x32x5xf32>) -> (tensor<1x32x5xf32>, tensor<1x32x5xf32>) %188 = stablehlo.reduce(%187#0 init: %cst_25) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x32x5xf32>, tensor<f32>) -> tensor<f32> %189 = stablehlo.transpose %188, dims = [] : (tensor<f32>) -> tensor<f32> %190 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32> %191 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %192 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %193 = 
stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %194 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %195 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %196 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %197 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %198 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %199 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %200 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %201 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32> %202 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32> %203 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>> %204 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32> %205 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32> %206 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32> %207 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32> %208 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32> return %189, %190, %191, %192, %193, %194, %195, %196, %197, %198, %199, %200, %201, %202, %203, %204, %205, %206, %207, %208 : tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32> 
} // Closes the preceding private function, whose body begins outside this view.
  // @main: entry point of the module.
  //
  // Dataflow visible in this function:
  //   1. %0-%18 transpose every argument; %19-%37 transpose each of those back
  //      with the same permutation. The permutations used ([0], [1, 0],
  //      [2, 1, 0]) are self-inverse, so %19-%37 carry the values and shapes of
  //      %arg0-%arg18.
  //   2. %cst-%cst_16 are zero tensors shaped like %0-%17, and %cst_17 is a
  //      scalar 1.0. %38 (the scalar) and %39-%56 (zeros, transposed to the
  //      shapes of %arg0-%arg17) are passed to enzyme.autodiff right after the
  //      19 primal operands. They are presumably the adjoint seeds for the
  //      callee's results: 1.0 for the first result (ret_activity
  //      enzyme_activenoneed) and zeros for the 18 enzyme_active results.
  //      TODO confirm against the Enzyme op definition.
  //   3. %57:37 = enzyme.autodiff of
  //      @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff". The first 18
  //      primal operands are marked enzyme_active and the last one (%37, the
  //      round-tripped %arg18) is enzyme_const.
  //   4. Every autodiff result is transposed twice (again an identity round
  //      trip) and returned in this order:
  //      - first %95-%112, derived from %57#19-#36. These have the types of the
  //        18 active arguments and are presumably their gradients.
  //      - then %113-%131, derived from %57#0-#18. These have the same types as
  //        %arg0-%arg18.
  //
  // NOTE(review): this file is reported to abort with the assertion
  // "only integers and floats have a bitwidth" in
  // mlir::Type::getIntOrFloatBitWidth when run through
  // `enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme`.
  // The tensor<...xcomplex<f32>> values are one plausible trigger, but this is
  // unconfirmed. Because this IR is a crash reproducer, keep the ops themselves
  // unchanged.
  func.func @main(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>,
                  %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>,
                  %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>,
                  %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>,
                  %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>,
                  %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>)
      -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
          tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>,
          tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
          tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
    // Step 1a: transpose each argument (%0-%18).
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    // Step 2a: zero constants shaped like %0-%17, plus a scalar 1.0 (%cst_17).
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x2xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_12 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x64xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<1x128xf32>
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
    %cst_17 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // Step 1b: transpose back. %19-%37 have the same shapes and values as %arg0-%arg18.
    %19 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %20 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %21 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %22 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %23 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %24 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %25 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %26 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %27 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %28 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %29 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %30 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %31 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %32 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %33 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %34 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %35 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %36 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %37 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Step 2b: presumed adjoint seeds. %38 is the scalar 1.0; %39-%56 are zeros
    // in the shapes of %arg0-%arg17.
    %38 = stablehlo.transpose %cst_17, dims = [] : (tensor<f32>) -> tensor<f32>
    %39 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %40 = stablehlo.transpose %cst_0, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %41 = stablehlo.transpose %cst_1, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %42 = stablehlo.transpose %cst_2, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %43 = stablehlo.transpose %cst_3, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %44 = stablehlo.transpose %cst_4, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %45 = stablehlo.transpose %cst_5, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %46 = stablehlo.transpose %cst_6, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %47 = stablehlo.transpose %cst_7, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %48 = stablehlo.transpose %cst_8, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %49 = stablehlo.transpose %cst_9, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %50 = stablehlo.transpose %cst_10, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %51 = stablehlo.transpose %cst_11, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %52 = stablehlo.transpose %cst_12, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %53 = stablehlo.transpose %cst_13, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %54 = stablehlo.transpose %cst_14, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %55 = stablehlo.transpose %cst_15, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %56 = stablehlo.transpose %cst_16, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 3: reverse-mode autodiff call.
    // Operands: 19 primal values (%19-%37), then the scalar seed %38, then 18 zero seeds (%39-%56).
    // activity: 18 x enzyme_active, then enzyme_const for %37.
    // ret_activity: enzyme_activenoneed, 18 x enzyme_active, then enzyme_const.
    // Results: %57#0-#18 have the types of %arg0-%arg18; %57#19-#36 have the types of
    // the 18 active arguments.
    %57:37 = enzyme.autodiff @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(
        %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37,
        %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56) {
        activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                    #enzyme<activity enzyme_const>],
        ret_activity = [#enzyme<activity enzyme_activenoneed>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>,
                        #enzyme<activity enzyme_const>]}
      : (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
         tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>,
         tensor<f32>,
         tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
         tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>)
     -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
         tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>,
         tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
         tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
         tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>)
    // Step 4a: transpose results %57#0-#18 (into %58-%76).
    %58 = stablehlo.transpose %57#0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %59 = stablehlo.transpose %57#1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %60 = stablehlo.transpose %57#2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %61 = stablehlo.transpose %57#3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %62 = stablehlo.transpose %57#4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %63 = stablehlo.transpose %57#5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %64 = stablehlo.transpose %57#6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %65 = stablehlo.transpose %57#7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %66 = stablehlo.transpose %57#8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %67 = stablehlo.transpose %57#9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %68 = stablehlo.transpose %57#10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %69 = stablehlo.transpose %57#11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %70 = stablehlo.transpose %57#12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %71 = stablehlo.transpose %57#13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %72 = stablehlo.transpose %57#14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %73 = stablehlo.transpose %57#15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %74 = stablehlo.transpose %57#16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %75 = stablehlo.transpose %57#17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %76 = stablehlo.transpose %57#18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    // Step 4b: transpose results %57#19-#36 (into %77-%94).
    %77 = stablehlo.transpose %57#19, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %78 = stablehlo.transpose %57#20, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %79 = stablehlo.transpose %57#21, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %80 = stablehlo.transpose %57#22, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %81 = stablehlo.transpose %57#23, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %82 = stablehlo.transpose %57#24, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %83 = stablehlo.transpose %57#25, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %84 = stablehlo.transpose %57#26, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %85 = stablehlo.transpose %57#27, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %86 = stablehlo.transpose %57#28, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %87 = stablehlo.transpose %57#29, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %88 = stablehlo.transpose %57#30, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %89 = stablehlo.transpose %57#31, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %90 = stablehlo.transpose %57#32, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %91 = stablehlo.transpose %57#33, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %92 = stablehlo.transpose %57#34, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %93 = stablehlo.transpose %57#35, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %94 = stablehlo.transpose %57#36, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 4c: transpose %77-%94 back (identity round trip); these are returned first.
    %95 = stablehlo.transpose %77, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %96 = stablehlo.transpose %78, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %97 = stablehlo.transpose %79, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %98 = stablehlo.transpose %80, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %99 = stablehlo.transpose %81, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %100 = stablehlo.transpose %82, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %101 = stablehlo.transpose %83, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %102 = stablehlo.transpose %84, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %103 = stablehlo.transpose %85, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %104 = stablehlo.transpose %86, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %105 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %106 = stablehlo.transpose %88, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %107 = stablehlo.transpose %89, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %108 = stablehlo.transpose %90, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %109 = stablehlo.transpose %91, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %110 = stablehlo.transpose %92, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %111 = stablehlo.transpose %93, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %112 = stablehlo.transpose %94, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 4d: transpose %58-%76 back (identity round trip); these are returned second.
    %113 = stablehlo.transpose %58, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %114 = stablehlo.transpose %59, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %115 = stablehlo.transpose %60, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %116 = stablehlo.transpose %61, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %117 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %118 = stablehlo.transpose %63, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %119 = stablehlo.transpose %64, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %120 = stablehlo.transpose %65, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %121 = stablehlo.transpose %66, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %122 = stablehlo.transpose %67, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %123 = stablehlo.transpose %68, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %124 = stablehlo.transpose %69, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %125 = stablehlo.transpose %70, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %126 = stablehlo.transpose %71, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %127 = stablehlo.transpose %72, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %128 = stablehlo.transpose %73, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %129 = stablehlo.transpose %74, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %130 = stablehlo.transpose %75, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %131 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Return the values derived from %57#19-#36 (18), then those from %57#0-#18 (19).
    return %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112,
           %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127, %128, %129, %130, %131
        : tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
          tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>,
          tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>,
          tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>,
          tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
  } // end @main
} // end module
Error message from a debug build:
enzymexlamlir-opt: external/llvm-project/mlir/lib/IR/Types.cpp:134: unsigned int mlir::Type::getIntOrFloatBitWidth() const: Assertion `isIntOrFloat() && "only integers and floats have a bitwidth"' failed. PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. Stack dump: 0. Program arguments: ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme envs/fno.mlir Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it): 0 enzymexlamlir-opt 0x000064943eed9fce 1 enzymexlamlir-opt 0x000064943eeda383 2 enzymexlamlir-opt 0x000064943eed7a84 3 enzymexlamlir-opt 0x000064943eed9af3 4 libc.so.6 0x00007709d944c1d0 5 libc.so.6 0x00007709d94a53f4 6 libc.so.6 0x00007709d944c120 gsignal + 32 7 libc.so.6 0x00007709d94334c3 abort + 223 8 libc.so.6 0x00007709d94333df 9 libc.so.6 0x00007709d9444177 10 enzymexlamlir-opt 0x000064943eba05c0 11 enzymexlamlir-opt 0x000064943b65a7c6 12 enzymexlamlir-opt 0x000064943b5a3f09 13 enzymexlamlir-opt 0x000064943b6c2f28 14 enzymexlamlir-opt 0x000064943e5ccfeb 15 enzymexlamlir-opt 0x000064943e5cda42 16 enzymexlamlir-opt 0x000064943af1b4d8 17 enzymexlamlir-opt 0x000064943e5d0d27 18 enzymexlamlir-opt 0x000064943e5cd769 19 enzymexlamlir-opt 0x000064943e5be495 20 enzymexlamlir-opt 0x000064943e5bf631 21 enzymexlamlir-opt 0x000064943e5c094b 22 enzymexlamlir-opt 0x000064943af1b4d8 23 enzymexlamlir-opt 0x000064943e5c06cf 24 enzymexlamlir-opt 0x000064943e5bf8fc 25 enzymexlamlir-opt 0x000064943e5bfab6 26 enzymexlamlir-opt 0x000064943b3e5548 27 enzymexlamlir-opt 0x000064943b5b2f6a 28 enzymexlamlir-opt 0x000064943ea6f1f9 29 enzymexlamlir-opt 0x000064943ea72eda 30 enzymexlamlir-opt 0x000064943af1b4d8 31 enzymexlamlir-opt 0x000064943ea7895f 32 enzymexlamlir-opt 0x000064943ea6f617 33 enzymexlamlir-opt 0x000064943ea6f959 34 enzymexlamlir-opt 0x000064943ea7180a 35 enzymexlamlir-opt 0x000064943ea715f9 36 
enzymexlamlir-opt 0x000064943aefe9fb 37 enzymexlamlir-opt 0x000064943aeff1dd 38 enzymexlamlir-opt 0x000064943aeff8d0 39 enzymexlamlir-opt 0x000064943af00ae7 40 enzymexlamlir-opt 0x000064943edc5a42 41 enzymexlamlir-opt 0x000064943edc5198 42 enzymexlamlir-opt 0x000064943aeffa5e 43 enzymexlamlir-opt 0x000064943aeffd6c 44 enzymexlamlir-opt 0x000064943aeffffd 45 enzymexlamlir-opt 0x000064943aeb2642 46 libc.so.6 0x00007709d9434e08 47 libc.so.6 0x00007709d9434ecc __libc_start_main + 140 48 enzymexlamlir-opt 0x000064943aeb20b5 [1] 100225 IOT instruction (core dumped) ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme
The text was updated successfully, but these errors were encountered:
xref: SciML/NeuralOperators.jl#52
Sorry, something went wrong.
No branches or pull requests
Unoptimized MLIR:
Error message from a debug build:
The text was updated successfully, but these errors were encountered: