Minor runtime cleanup, update ternary.cpp to be generic (#1255)
jnie-TT authored Nov 14, 2024
1 parent db38351 commit 30d6fff
Showing 46 changed files with 161 additions and 162 deletions.
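
Most of the files shown here are mechanical include-guard renames: the old TTNN_RUNTIME_* guards are replaced with guards that mirror each header's path, which keeps them unique across the tree. As a hedged illustration of the new convention (hypothetical file, not part of this commit), a header at runtime/lib/ttnn/operations/foo/bar.h would look like:

// runtime/lib/ttnn/operations/foo/bar.h -- hypothetical example of the
// path-derived include-guard convention applied throughout this commit.
#ifndef RUNTIME_LIB_TTNN_OPERATIONS_FOO_BAR_H
#define RUNTIME_LIB_TTNN_OPERATIONS_FOO_BAR_H

// ... declarations ...

#endif // RUNTIME_LIB_TTNN_OPERATIONS_FOO_BAR_H
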
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/types.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_TYPES_H
-#define TTNN_RUNTIME_TYPES_H
+#ifndef TT_RUNTIME_TTNN_TYPES_H
+#define TT_RUNTIME_TTNN_TYPES_H

#include "tt/runtime/detail/ttnn.h"

4 changes: 2 additions & 2 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_UTILS_H
-#define TTNN_RUNTIME_UTILS_H
+#ifndef TT_RUNTIME_TTNN_UTILS_H
+#define TT_RUNTIME_TTNN_UTILS_H

#include "flatbuffers/vector.h"
#include "ttmlir/Target/Common/types_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/ccl/all_gather.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_ALL_GATHER_H
-#define TTNN_RUNTIME_ALL_GATHER_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CCL_ALL_GATHER_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CCL_ALL_GATHER_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/operations/context/get_device.cpp
@@ -26,7 +26,7 @@ calculateMeshOffset(const ::ttnn::MeshDevice &parentMesh,
}
}
}
-throw std::runtime_error("Could not find any desired device in parent mesh");
+LOG_FATAL("Could not find any desired device in parent mesh");
}

static std::shared_ptr<::ttnn::MeshDevice>
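
Alongside the renames, the error paths in get_device.cpp above and in the creation and eltwise .cpp files below switch from throwing std::runtime_error / std::invalid_argument to LOG_FATAL. The macro's definition is not part of this diff; the assumption is that it logs and aborts instead of unwinding. A minimal sketch of that style of macro (an assumption, not the actual ttmlir implementation) might look like:

#include <cstdlib>
#include <iostream>

// Assumed shape only: the real LOG_FATAL comes from the runtime's logging
// utilities, which are not shown in this commit.
#define SKETCH_LOG_FATAL(msg)                                                  \
  do {                                                                         \
    std::cerr << "FATAL: " << (msg) << " (" << __FILE__ << ":" << __LINE__     \
              << ")" << std::endl;                                             \
    std::abort();                                                              \
  } while (0)
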
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/context/get_device.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_GET_DEVICE_H
-#define TTNN_RUNTIME_GET_DEVICE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CONTEXT_GET_DEVICE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CONTEXT_GET_DEVICE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/conv/conv2d.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_CONV2D_H
-#define TTNN_RUNTIME_CONV2D_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONV2D_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CONV_CONV2D_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/operations/creation/empty.cpp
@@ -99,7 +99,7 @@ void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) {
} else if (config.numShards > 1) {
out = createEmptyOnMultiDevice(context, config, op->device());
} else {
-throw std::invalid_argument("Unsupported num shards");
+LOG_FATAL("Unsupported num shards");
}
utils::updateTensorPool(tensorPool, out, op->out()->global_id());
}
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/creation/empty.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_EMPTY_H
-#define TTNN_RUNTIME_EMPTY_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CREATION_EMPTY_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CREATION_EMPTY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/operations/creation/full.cpp
@@ -114,7 +114,7 @@ void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) {
} else if (config.numShards > 1) {
out = createFullOnMultiDevice(context, config, deviceRef);
} else {
-throw std::invalid_argument("Unsupported num shards");
+LOG_FATAL("Unsupported num shards");
}
utils::updateTensorPool(tensorPool, out, op->out()->global_id());
}
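
In empty.cpp and full.cpp only the error branch changes. The visible hunks show the tail of a shard-count dispatch: more than one shard goes to a multi-device creation helper, and anything unexpected is now fatal; the branch above the hunk presumably handles the single-shard case. A self-contained toy of that control-flow shape (stand-in types and helpers; the single-shard condition is an assumption):

#include <cstdint>
#include <stdexcept>
#include <string>

// Toy stand-ins; the runtime uses ::ttnn::Tensor and device-aware helpers.
struct Tensor { std::string desc; };

static Tensor createOnSingleDevice() { return {"single-device tensor"}; }
static Tensor createOnMultiDevice(uint32_t shards) {
  return {"multi-device tensor over " + std::to_string(shards) + " shards"};
}

static Tensor createSharded(uint32_t numShards) {
  if (numShards == 1) {  // assumed: handled above the visible hunk
    return createOnSingleDevice();
  }
  if (numShards > 1) {   // mirrors the `else if` shown in the diff
    return createOnMultiDevice(numShards);
  }
  // The commit replaces a throw here with LOG_FATAL (log and abort).
  throw std::invalid_argument("Unsupported num shards");
}
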
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/creation/full.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_FULL_H
-#define TTNN_RUNTIME_FULL_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CREATION_FULL_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_CREATION_FULL_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/concat.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_CONCAT_H
-#define TTNN_RUNTIME_CONCAT_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_CONCAT_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_CONCAT_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/reshape.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_RESHAPE_H
-#define TTNN_RUNTIME_RESHAPE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_RESHAPE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_RESHAPE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/slice.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_SLICE_H
-#define TTNN_RUNTIME_SLICE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_SLICE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_SLICE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/transpose.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_TRANSPOSE_H
-#define TTNN_RUNTIME_TRANSPOSE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_TRANSPOSE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_DATA_MOVEMENT_TRANSPOSE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/deletion/deallocate.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_DEALLOCATE_H
-#define TTNN_RUNTIME_DEALLOCATE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_DELETION_DEALLOCATE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_DELETION_DEALLOCATE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
37 changes: 18 additions & 19 deletions runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
@@ -10,20 +10,19 @@

namespace tt::runtime::ttnn::operations::binary {

-static void runEltwiseBinaryOP(
+static void runEltwiseBinaryOp(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
-std::function<::ttnn::Tensor(
+const std::function<::ttnn::Tensor(
const ::ttnn::Tensor &, const ::ttnn::Tensor &,
const std::optional<const ::ttnn::DataType> &,
const std::optional<::tt::tt_metal::MemoryConfig> &,
std::optional<::ttnn::Tensor>,
std::optional<::ttnn::operations::unary::FusedActivations>,
-std::optional<::ttnn::operations::unary::UnaryWithParam>)>
-ttnnOp) {
+std::optional<::ttnn::operations::unary::UnaryWithParam>)> &ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
-getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);
+getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs);

::ttnn::DataType outputDataType = utils::getDataType(op->out());
::tt::tt_metal::MemoryConfig outputMemoryConfig =
@@ -39,59 +38,59 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
switch (op->type()) {
/* Eltwise Binary */
case ::tt::target::ttnn::EltwiseOpType::Add: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::add);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::add);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LogicalAnd: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_and);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::logical_and);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LogicalOr: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_or);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::logical_or);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LogicalXor: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_xor);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::logical_xor);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Multiply: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::multiply);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::multiply);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Subtract: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::subtract);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::subtract);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Equal: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::eq);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::eq);
break;
}
case ::tt::target::ttnn::EltwiseOpType::NotEqual: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::ne);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::ne);
break;
}
case ::tt::target::ttnn::EltwiseOpType::GreaterEqual: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::ge);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::ge);
break;
}
case ::tt::target::ttnn::EltwiseOpType::GreaterThan: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::gt);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::gt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LessEqual: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::le);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::le);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LessThan: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::lt);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::lt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Div: {
-runEltwiseBinaryOP(op, tensorPool, ::ttnn::divide);
+runEltwiseBinaryOp(op, tensorPool, ::ttnn::divide);
break;
}
default:
-throw std::invalid_argument("Unsupported Eltwise Binary operation");
+LOG_FATAL("Unsupported Eltwise Binary operation");
}
}

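
In binary.cpp above and binary_composite.cpp below, the changes are renames (runEltwiseBinaryOP → runEltwiseBinaryOp, and similarly for the input-tensor helpers) plus passing the ttnnOp callable as a const std::function reference instead of by value. A small self-contained illustration of that parameter-passing choice (toy types, not the runtime signatures):

#include <functional>
#include <iostream>

// Taking std::function by value copies the type-erased wrapper (and any
// captured state) at every call site; a const reference avoids the copy
// while still accepting lambdas, function pointers, and functors.
static int applyBinaryOp(int lhs, int rhs,
                         const std::function<int(int, int)> &op) {
  return op(lhs, rhs);
}

int main() {
  const std::function<int(int, int)> add = [](int a, int b) { return a + b; };
  std::cout << applyBinaryOp(2, 3, add) << "\n"; // prints 5
  return 0;
}
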
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/eltwise/binary/binary.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_ELTWISE_BINARY_H
-#define TTNN_RUNTIME_ELTWISE_BINARY_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_BINARY_BINARY_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_BINARY_BINARY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
20 changes: 9 additions & 11 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
@@ -9,16 +9,15 @@

namespace tt::runtime::ttnn::operations::binary::composite {

-static void runEltwiseBinaryCompositeOP(
+static void runEltwiseBinaryCompositeOp(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
-std::function<
-::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
-const std::optional<::tt::tt_metal::MemoryConfig> &)>
-ttnnOp) {
+const std::function<::ttnn::Tensor(
+const ::ttnn::Tensor &, const ::ttnn::Tensor &,
+const std::optional<::tt::tt_metal::MemoryConfig> &)> &ttnnOp) {

::ttnn::Tensor *lhs = nullptr;
::ttnn::Tensor *rhs = nullptr;
-getEltwiseBinaryOPInputTensors(op, tensorPool, &lhs, &rhs);
+getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());
@@ -31,20 +30,19 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
case ::tt::target::ttnn::EltwiseOpType::Maximum: {
-runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
+runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
-runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
+runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::minimum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Remainder: {
-runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::remainder);
+runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::remainder);
break;
}
default:
-throw std::invalid_argument(
-"Unsupported Eltwise Binary Composite operation");
+LOG_FATAL("Unsupported Eltwise Binary Composite operation");
}
}

4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H
-#define TTNN_RUNTIME_ELTWISE_BINARY_COMPOSITE_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_BINARY_COMPOSITE_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_BINARY_COMPOSITE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
30 changes: 20 additions & 10 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp
@@ -7,26 +7,36 @@
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
+#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::operations::ternary {

-void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
-if (op->type() != ::tt::target::ttnn::EltwiseOpType::Where) {
-throw std::invalid_argument("Unsupported Eltwise Ternary operation");
-}
-
-ProgramTensorPool &tensorPool = context.getTensorPool();
-
+static void runEltwiseTernaryWhereOp(
+const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
+const std::function<::ttnn::Tensor(
+const ::ttnn::Tensor &, const ::ttnn::Tensor &, const ::ttnn::Tensor &,
+const std::optional<::tt::tt_metal::MemoryConfig> &)> &ttnnOp) {
::ttnn::Tensor *first = nullptr;
::ttnn::Tensor *second = nullptr;
::ttnn::Tensor *third = nullptr;
-getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third);
+getEltwiseTernaryOpInputTensors(op, tensorPool, &first, &second, &third);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

-::ttnn::Tensor out =
-::ttnn::where(*first, *second, *third, outputMemoryConfig);
+::ttnn::Tensor out = ttnnOp(*first, *second, *third, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

+void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
+ProgramTensorPool &tensorPool = context.getTensorPool();
+switch (op->type()) {
+case ::tt::target::ttnn::EltwiseOpType::Where: {
+runEltwiseTernaryWhereOp(op, tensorPool, ::ttnn::where);
+break;
+}
+default:
+LOG_FATAL("Unsupported ternary operation");
+}
+}
} // namespace tt::runtime::ttnn::operations::ternary
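
This is the substantive change in the commit: ternary.cpp now mirrors the binary dispatch pattern — a helper that takes the TTNN callable as a const std::function reference plus a switch in run() — so further ternary ops can be added as new cases. A self-contained sketch of the pattern with toy types (not the runtime or TTNN API):

#include <functional>
#include <iostream>
#include <stdexcept>

// Toy stand-ins; the real code uses ::ttnn::Tensor and the flatbuffer
// generated EltwiseOpType enum.
struct Tensor { float value; };
enum class EltwiseOpType { Where };

// Generic helper: the concrete ternary kernel is injected by the caller.
static Tensor runTernaryOp(
    const Tensor &a, const Tensor &b, const Tensor &c,
    const std::function<Tensor(const Tensor &, const Tensor &, const Tensor &)>
        &ttnnOp) {
  return ttnnOp(a, b, c);
}

static Tensor run(EltwiseOpType type, const Tensor &cond, const Tensor &x,
                  const Tensor &y) {
  switch (type) {
  case EltwiseOpType::Where:
    return runTernaryOp(cond, x, y,
                        [](const Tensor &c, const Tensor &t, const Tensor &f) {
                          return c.value != 0.0f ? t : f; // toy "where"
                        });
  default:
    throw std::invalid_argument("Unsupported ternary operation");
  }
}

int main() {
  // Selects x because cond is non-zero.
  std::cout << run(EltwiseOpType::Where, {1.0f}, {10.0f}, {20.0f}).value
            << "\n"; // prints 10
  return 0;
}
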
4 changes: 2 additions & 2 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.h
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

-#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H
-#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H
+#ifndef RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H
+#define RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"