Skip to content

Commit

Permalink
[TTIR] Remove TTIR operand_constraints (#1388)
Browse files Browse the repository at this point in the history
  • Loading branch information
jserbedzijaTT authored Dec 12, 2024
1 parent d2cd95c commit f22c416
Show file tree
Hide file tree
Showing 245 changed files with 754 additions and 1,462 deletions.
3 changes: 1 addition & 2 deletions docs/src/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,10 @@ simpler.
So what does MLIR look like, how does it work and get parsed? The
hierarchy of an MLIR Module is as shown:
```
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[<#tt.arch<wormhole_b0>, #tt.grid<8x8>>], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
}
Expand Down
109 changes: 39 additions & 70 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
}];

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
TT_OperandConstraintArrayAttr:$operand_constraints);
Variadic<AnyRankedTensor>:$outputs);
let results = (outs Variadic<AnyRankedTensor>:$results);
}

Expand All @@ -199,9 +198,9 @@ class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :

let builders =
[
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints),
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
[{
build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints);
build($_builder, $_state, {out.getType()}, {first, second, third}, out);
}]>
];
}
Expand All @@ -222,9 +221,9 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :

let builders =
[
OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints),
OpBuilder<(ins "Value": $in, "Value": $out),
[{
build($_builder, $_state, {out.getType()}, in, out, operand_constraints);
build($_builder, $_state, {out.getType()}, in, out);
}]>
];
}
Expand Down Expand Up @@ -408,14 +407,13 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> tra

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
F32Attr:$parameter,
TT_OperandConstraintArrayAttr:$operand_constraints);
F32Attr:$parameter);

let builders =
[
OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter, "ArrayAttr": $operand_constraints),
OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter),
[{
build($_builder, $_state, {out.getType()}, {in}, {out}, parameter, operand_constraints);
build($_builder, $_state, {out.getType()}, {in}, {out}, parameter);
}]>
];
}
Expand Down Expand Up @@ -452,9 +450,9 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :

let builders =
[
OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints),
OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out),
[{
build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints);
build($_builder, $_state, {out.getType()}, {lhs, rhs}, out);
}]>
];
}
Expand Down Expand Up @@ -568,8 +566,7 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
BoolAttr:$keep_dim,
OptionalAttr<I32ArrayAttr>:$dim_arg,
TT_OperandConstraintArrayAttr:$operand_constraints);
OptionalAttr<I32ArrayAttr>:$dim_arg);

let results = (outs AnyRankedTensor:$result);

Expand Down Expand Up @@ -636,8 +633,7 @@ def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -657,8 +653,7 @@ def TTIR_EmbeddingBackwardOp : TTIR_DPSOp<"embedding_backward"> {
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$weight,
AnyRankedTensor:$in_gradient,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -677,8 +672,7 @@ def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dimension,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dimension);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -698,8 +692,7 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim0,
SI32Attr:$dim1,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dim1);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -718,8 +711,7 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dim);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -739,8 +731,7 @@ def TTIR_UpdateCacheOp : TTIR_DPSOp<"update_cache"> {
let arguments = (ins AnyRankedTensor:$cache,
AnyRankedTensor:$input,
AnyRankedTensor:$update_index,
I32Attr:$batch_offset,
TT_OperandConstraintArrayAttr:$operand_constraints);
I32Attr:$batch_offset);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -759,8 +750,7 @@ def TTIR_FillCacheOp : TTIR_DPSOp<"fill_cache"> {

let arguments = (ins AnyRankedTensor:$cache,
AnyRankedTensor:$input,
I32Attr:$batch_offset,
TT_OperandConstraintArrayAttr:$operand_constraints);
I32Attr:$batch_offset);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -779,8 +769,7 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I64ArrayAttr:$dimension,
TT_OperandConstraintArrayAttr:$operand_constraints);
I64ArrayAttr:$dimension);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -807,8 +796,7 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
SI32Attr:$padding_left,
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$padding_bottom);

let results = (outs AnyRankedTensor:$result);

Expand Down Expand Up @@ -841,8 +829,7 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
DenseBoolArrayAttr:$window_reversal,
TTIR_ConvolutionLayoutAttr:$convolution_layout,
ConfinedAttr<I64Attr, [IntPositive]>:$feature_group_count,
ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count,
TT_OperandConstraintArrayAttr:$operand_constraints
ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count
);

let results = (outs AnyRankedTensor);
Expand All @@ -869,8 +856,7 @@ def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
DenseI64ArrayAttr:$start_index_map,
SI64Attr:$index_vector_dim,
DenseI64ArrayAttr:$slice_sizes,
BoolAttr:$indices_are_sorted,
TT_OperandConstraintArrayAttr:$operand_constraints);
BoolAttr:$indices_are_sorted);
let results = (outs AnyRankedTensor:$result);
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
Expand All @@ -891,8 +877,7 @@ def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> {
DenseI64ArrayAttr:$window_strides,
DenseI64ArrayAttr:$base_dilations,
DenseI64ArrayAttr:$window_dilations,
DenseI64ArrayAttr:$padding,
TT_OperandConstraintArrayAttr:$operand_constraints
DenseI64ArrayAttr:$padding
);

let results = (outs Variadic<AnyRankedTensor>);
Expand All @@ -918,8 +903,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
SI32Attr:$padding_left,
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$padding_bottom);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -938,8 +922,7 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I32ArrayAttr:$shape,
TT_OperandConstraintArrayAttr:$operand_constraints);
I32ArrayAttr:$shape);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -964,8 +947,7 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
AnyRankedTensor:$output,
I32ArrayAttr:$begins,
I32ArrayAttr:$ends,
I32ArrayAttr:$step,
TT_OperandConstraintArrayAttr:$operand_constraints);
I32ArrayAttr:$step);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -991,8 +973,7 @@ def TTIR_SelectOp: TTIR_DPSOp<"select"> {
SI32Attr:$dim,
SI32Attr:$begin,
SI32Attr:$length,
DefaultValuedOptionalAttr<SI32Attr, "0">:$stride,
TT_OperandConstraintArrayAttr:$operand_constraints);
DefaultValuedOptionalAttr<SI32Attr, "0">:$stride);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1017,8 +998,7 @@ def TTIR_IndexOp: TTIR_DPSOp<"index"> {
I32Attr:$dim,
I32Attr:$begin,
I32Attr:$end,
I32Attr:$step,
TT_OperandConstraintArrayAttr:$operand_constraints);
I32Attr:$step);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1038,8 +1018,7 @@ def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dim);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1058,8 +1037,7 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dim);

let results = (outs AnyRankedTensor:$result);

Expand Down Expand Up @@ -1087,8 +1065,7 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
F32Attr:$min,
F32Attr:$max,
TT_OperandConstraintArrayAttr:$operand_constraints);
F32Attr:$max);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
Expand Down Expand Up @@ -1191,8 +1168,7 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> {
}];

let arguments = (ins AnyRankedTensor:$output,
ElementsAttr:$value,
TT_OperandConstraintArrayAttr:$operand_constraints);
ElementsAttr:$value);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1217,8 +1193,7 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1238,8 +1213,7 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {

let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

Expand Down Expand Up @@ -1362,8 +1336,7 @@ def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> {
I32Attr:$index_vector_dim,
BoolAttr:$indices_are_sorted,
BoolAttr:$unique_indices,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);
AnyRankedTensor:$output);

let regions = (region SizedRegion<1>:$update_computation);

Expand Down Expand Up @@ -1391,8 +1364,7 @@ def TTIR_KernelOp : TTIR_DPSOp<"kernel", [AttrSizedOperandSegments]> {
let arguments = (ins FlatSymbolRefAttr:$op,
FlatSymbolRefAttr:$kind,
Variadic<AnyRankedTensorOrMemRef>:$inputs,
Variadic<AnyRankedTensorOrMemRef>:$outputs,
TT_OperandConstraintArrayAttr:$operand_constraints);
Variadic<AnyRankedTensorOrMemRef>:$outputs);
let results = (outs Variadic<AnyRankedTensorOrMemRef>:$results);
}

Expand All @@ -1417,8 +1389,7 @@ def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);
SI32Attr:$dim);

let results = (outs AnyRankedTensor:$result);

Expand All @@ -1442,8 +1413,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {
SI32Attr:$dim,
OptionalAttr<SI32Attr>:$channel_handle,
UnitAttr:$use_global_device_ids,
TT_ReduceTypeAttr:$reduce_type,
TT_OperandConstraintArrayAttr:$operand_constraints
TT_ReduceTypeAttr:$reduce_type
);

let results = (outs Variadic<AnyRankedTensor>:$results);
Expand Down Expand Up @@ -1490,8 +1460,7 @@ def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> {
AnyRankedTensor:$output,
TT_MeshShardTypeAttr:$shard_type,
TT_MeshShardDirectionAttr:$shard_direction,
TT_GridAttr:$shard_shape,
TT_OperandConstraintArrayAttr:$operand_constraints
TT_GridAttr:$shard_shape
);

let results = (outs AnyRankedTensor:$result);
Expand Down
10 changes: 0 additions & 10 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@ include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
def TTIROpInterface : OpInterface<"TTIROp"> {
let cppNamespace = "::mlir::tt::ttir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the constraints on the operands of this operation.
}],
/*retTy=*/"::mlir::ArrayAttr",
/*methodName=*/"getOperandConstraints",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
>,
InterfaceMethod<
/*desc=*/[{
Get the device of the current scope.
Expand Down
Loading

0 comments on commit f22c416

Please sign in to comment.