Skip to content

Commit

Permalink
tfl const op
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelpoluektov committed Dec 10, 2024
1 parent 23280cc commit eca04ac
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions xformer/Transforms/OptimizeTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ struct FoldDoubleTransposePattern : public OpRewritePattern<TFL::TransposeOp> {
{static_cast<int64_t>(permVec.size())}, rewriter.getIntegerType(32));

auto permAttr = DenseIntElementsAttr::get(permType, permVec);
auto permConstOp = rewriter.create<arith::ConstantOp>(transposeOp.getLoc(),
permType, permAttr);
auto permConstOp =
rewriter.create<TFL::ConstOp>(transposeOp.getLoc(), permType, permAttr);

// Create new transposeOp
auto newTransposeOp = rewriter.create<TFL::TransposeOp>(
Expand Down Expand Up @@ -562,8 +562,8 @@ struct MoveTransposeForwardOverConcatOpPattern
auto permType = RankedTensorType::get(
{static_cast<int64_t>(permVec.size())}, rewriter.getIntegerType(32));
auto permAttr = DenseIntElementsAttr::get(permType, permVec);
auto permConstOp = rewriter.create<arith::ConstantOp>(concatOp.getLoc(),
permType, permAttr);
auto permConstOp =
rewriter.create<TFL::ConstOp>(concatOp.getLoc(), permType, permAttr);

// Create the new TransposeOp with the original output type
auto newTransposeOp = rewriter.create<TFL::TransposeOp>(
Expand Down Expand Up @@ -631,7 +631,7 @@ struct HoistTransposeWCHAbovePadPattern
std::vector<int32_t> paddingValues{0, 0, 1, 1, 1, 1, 0, 0};
auto paddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), paddingValues);
auto paddingOp = rewriter.create<arith::ConstantOp>(
auto paddingOp = rewriter.create<TFL::ConstOp>(
padOp->getLoc(), RankedTensorType::get({4, 2}, rewriter.getI32Type()),
paddingAttr);
auto newPad = rewriter.create<TFL::PadOp>(
Expand Down

0 comments on commit eca04ac

Please sign in to comment.