Reverting tilizing f32 on device (#1669)
Reverting commit ad87a44 and PR #1647, which introduced a llama matmul test
failure in forge-fe: tenstorrent/tt-forge-fe#954 -
https://github.com/tenstorrent/tt-forge-fe/actions/runs/12501472126/job/34879043368?pr=954

Added a matmul repro test that fails before the revert and passes after it.
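
For context, PR #1647 had introduced a shared helper that allowed both bf16 and
f32 tensors to be tilized on device; this revert restores the bf16-only gating.
A minimal, self-contained sketch of the predicate before and after the revert
(the DataType enum here is a stand-in for the real mlir::tt type, not the
actual definition):

// Pre-revert (PR #1647): bf16 and f32 could both be tilized on device.
enum class DataType { Float32, BFloat16, UInt32 };

static bool canTilizeDataTypeOnDevice(DataType dataType) {
  return dataType == DataType::BFloat16 or dataType == DataType::Float32;
}

// Post-revert: every call site checks for bf16 directly, e.g.
//   if (info.shouldTilize() and output.dataType == DataType::BFloat16) { ... }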
sdjordjevicTT authored Dec 26, 2024
1 parent cfc6f53 commit 6d04d25
Showing 5 changed files with 69 additions and 74 deletions.
78 changes: 33 additions & 45 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -30,11 +30,6 @@ namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

// TTNN supports device tilize for bf16 and fp32
static bool canTilizeDataTypeOnDevice(DataType dataType) {
return dataType == DataType::BFloat16 or dataType == DataType::Float32;
}

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {

public:
@@ -432,9 +427,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize, and the data type can be tilized on device, tilize
* on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
/* If we should tilize, and the data type is bfloat16, we can tilize on
* device */
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -445,10 +440,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize, and the data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(output.dataType)) {
/* If we should tilize, and the data type is not bfloat16, we tilize on host
*/
if (info.shouldTilize() and output.dataType != DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -519,9 +513,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we need to tilize and the input datatype is tilizeable on device,
/* If we need to tilize and the input datatype is bfloat16
we can tilize on device and then typecast afterwards */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToDeviceOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -534,9 +528,9 @@ class TTNNDecomposeLayouts
return;
}

/* if we need to tilize and the output data type can be tilized on device,
/* if we need to tilize and the output data type is bfloat16
we can typecast on host and tilize on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(output.dataType)) {
if (info.shouldTilize() and output.dataType == DataType::BFloat16) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -549,11 +543,10 @@ class TTNNDecomposeLayouts
return;
}

/* if we need to tilize and the input/output data types cannot be tilized on
* device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
not canTilizeDataTypeOnDevice(output.dataType)) {
/* if we need to tilize and the input/ output data types are not bfloat16 do
* everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
output.dataType != DataType::BFloat16) {
currentInput =
this->createTypecastOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -646,10 +639,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize on device
/* If we should tilize and the input data type is bfloat16, tilize on device
*/
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
@@ -660,10 +652,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -673,10 +664,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we want to tilize a device tensor whose data type cannot be tilized on
* device, we need to tilize on host and move it back */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we want to tilize a device tensor that is not bfloat16, we need to
* tilize on host and move it back */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -791,9 +781,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize and typecast on device */
if (info.shouldTilize() and canTilizeDataTypeOnDevice(input.dataType)) {
/* If we should tilize and the input data type is bfloat16, tilize and
* typecast on device */
if (info.shouldTilize() and input.dataType == DataType::BFloat16) {
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
currentInput =
@@ -806,10 +796,9 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
and we want to read back from device, do everything on host */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat16 and we want
to read back from device do everything on host */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
opsToCreate.createFromDeviceOp) {
currentInput =
this->createFromDeviceOpIfNeeded(op, rewriter, currentInput, info);
@@ -821,11 +810,10 @@ class TTNNDecomposeLayouts
return;
}

/* If we should tilize and the input data type cannot be tilized on device,
and we don't want to read back from device - tilize on host, move back to
device, and typecast on device */
if (info.shouldTilize() and
not canTilizeDataTypeOnDevice(input.dataType) and
/* If we should tilize and the input data type is not bfloat 16 and we don't
want to read back from device: tilize on host, move back to device, and
typecast on device */
if (info.shouldTilize() and input.dataType != DataType::BFloat16 and
not opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
@@ -875,7 +863,7 @@ class TTNNDecomposeLayouts
/*
* Logic for creating ops. Conditions/constraints include:
* - When possible, we want to execute operations on device.
* - Tilize on device requires dataformat of BFLOAT16 or FLOAT32.
* - Tilize on device requires dataformat of BFLOAT16.
* - Typecast on device requires TILIZED tensor.
* - Untilize on device requires even width, and page size >
* sizeof(uint32_t). For now, we will always untilize on host. We rarely
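
To make the constraint list above concrete: after the revert, when a tensor must
be both tilized and typecast, TTNNDecomposeLayouts picks where each step runs
based solely on which side of the cast is bf16. The sketch below is illustrative
only; the step names are shorthand for the ToDeviceOp/ToLayoutOp/TypecastOp/
ToMemoryConfigOp ops the pass actually creates, and the function itself is not
part of the pass.

#include <string>
#include <vector>

enum class DataType { Float32, BFloat16, UInt32 };

// Branch order for the "tilize + typecast" case, mirroring the hunks above.
std::vector<std::string> planTilizeAndTypecast(DataType input, DataType output) {
  if (input == DataType::BFloat16)
    // Input already bf16: move to device, tilize there, then typecast.
    return {"toDevice", "toLayout(tile)", "typecast", "toMemoryConfig"};
  if (output == DataType::BFloat16)
    // Output will be bf16: typecast on host first, then tilize on device.
    return {"typecast", "toDevice", "toLayout(tile)", "toMemoryConfig"};
  // Neither side is bf16: typecast and tilize on host, then move to device.
  return {"typecast", "toLayout(tile)", "toDevice", "toMemoryConfig"};
}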
35 changes: 15 additions & 20 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp
@@ -8,10 +8,6 @@

namespace tt::runtime::ttnn {

static bool canTilizeDataTypeOnDevice(::ttnn::DataType dataType) {
return dataType == ::ttnn::DataType::BFLOAT16 or
dataType == ::ttnn::DataType::FLOAT32;
}
//
// LayoutConverter APIs
//
@@ -107,14 +103,14 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast(
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toMemoryConfigIfNeeded(out);
@@ -151,24 +147,24 @@ ::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast(
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toDeviceIfNeeded(out, targetDevice);
out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
not canTilizeDataTypeOnDevice(outputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
outputDesc.dataType != ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = typecastIfNeeded(input);
out = toLayoutIfNeeded(out);
out = toDeviceIfNeeded(out, targetDevice);
@@ -221,26 +217,25 @@ ::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast(
return out;
}

/* If we should tilize and the input data type can be tilized on device,
* tilize on device
/* If we should tilize and the input data type is bfloat16, tilize on device
*/
if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

/* If we should tilize and the input data type cannot be tilized on device,
* tilize on host */
if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
/* If we should tilize and the input data type is not bfloat16, tilize on
* host */
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
@@ -292,23 +287,23 @@ LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) {
return out;
}

if (shouldTilize and canTilizeDataTypeOnDevice(inputDesc.dataType)) {
if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) {
::ttnn::Tensor out = toLayoutIfNeeded(input);
out = typecastIfNeeded(out);
out = toMemoryConfigIfNeeded(out);
out = fromDeviceIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
shouldFromDevice) {
::ttnn::Tensor out = fromDeviceIfNeeded(input);
out = toLayoutIfNeeded(out);
out = typecastIfNeeded(out);
return out;
}

if (shouldTilize and not canTilizeDataTypeOnDevice(inputDesc.dataType) and
if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and
not shouldFromDevice) {
LOG_WARNING("Currently no constraint checking for on-device tilize.");
::ttnn::Tensor out = toLayoutIfNeeded(input);
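
The runtime's LayoutConverter (types.cpp above) applies the same bf16-only rule
when the input already lives on device. A rough, self-contained sketch of the
resulting plan for the no-typecast path; the step strings stand in for the real
toLayoutIfNeeded/fromDeviceIfNeeded calls and are assumptions for illustration:

#include <string>
#include <vector>

enum class DType { Float32, BFloat16 };

std::vector<std::string> planDeviceInputNoTypecast(DType dtype, bool shouldTilize,
                                                   bool shouldFromDevice) {
  if (shouldTilize && dtype == DType::BFloat16) {
    // bf16: tilize on device, optionally read back afterwards.
    std::vector<std::string> plan{"toLayout(tile) on device", "toMemoryConfig"};
    if (shouldFromDevice)
      plan.push_back("fromDevice");
    return plan;
  }
  if (shouldTilize && shouldFromDevice)
    // Non-bf16 tensor that is read back anyway: move to host, tilize there.
    return {"fromDevice", "toLayout(tile) on host"};
  if (shouldTilize)
    // Non-bf16 tensor that must stay on device: tilize on device anyway; this
    // is the LOG_WARNING("no constraint checking") branch in the diff above.
    return {"toLayout(tile) on device"};
  // Untilize and no-op cases are omitted from this sketch.
  return {};
}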
@@ -7,15 +7,16 @@
#dram = #ttnn.buffer_type<dram>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, <interleaved>>

module attributes {tt.device = #device} {
func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -26,9 +27,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
@@ -39,9 +40,9 @@ module attributes {tt.device = #device} {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x4>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout1>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<64x128>>, <interleaved>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2>
%4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2>
%5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout>
%6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<row_major>}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout>
return %6 : tensor<64x128xbf16, #ttnn_layout>
}
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<1x11x2048xf32>, %arg1: tensor<2048x128256xf32>) -> tensor<1x11x128256xf32> {
%0 = tensor.empty() : tensor<1x11x128256xf32>
// CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]]
%1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x11x2048xf32>, tensor<2048x128256xf32>, tensor<1x11x128256xf32>) -> tensor<1x11x128256xf32>
return %1 : tensor<1x11x128256xf32>
}
}
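
The three RUN lines above lower the TTIR matmul through the ttir-to-ttnn backend
pipeline, FileCheck that a ttnn.matmul is emitted, and translate the result to a
flatbuffer. Assuming the repository's usual llvm-lit harness (an assumption, not
shown in this diff), the repro can be invoked on its own roughly like:

llvm-lit -v test/ttmlir/Silicon/TTNN/matmul/llama_matmul.mlir   # exact path/config depends on the local build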
File renamed without changes.
