Skip to content

Commit

Permalink
#sdy add option to avoid escaping attribute when adding to frontend a…
Browse files Browse the repository at this point in the history
…ttrs.

PiperOrigin-RevId: 705515578
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Dec 12, 2024
1 parent 33f696b commit fb8e7d5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
21 changes: 13 additions & 8 deletions xla/service/spmd/shardy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,18 @@ DictionaryAttr getFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index) {

namespace {

mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder) {
mlir::StringAttr getStringAttribute(Attribute attr, mlir::OpBuilder& builder,
bool escapeAttr) {
std::string value;
if (auto stringAttr = mlir::dyn_cast<StringAttr>(attr)) {
if (!escapeAttr) {
return stringAttr;
}
value = stringAttr.getValue().str();
} else {
value = mlir::sdy::attributeToString(attr);
}
return builder.getStringAttr(absl::CEscape(value));
return builder.getStringAttr(escapeAttr ? absl::CEscape(value) : value);
}

SmallVector<NamedAttribute> getExistingFrontendAttributes(
Expand All @@ -87,9 +91,9 @@ SmallVector<NamedAttribute> getExistingFrontendAttributes(
}

void setFrontendAttribute(SmallVector<NamedAttribute>& existingAttributes,
StringRef name, Attribute value) {
StringRef name, Attribute value, bool escapeAttr) {
mlir::OpBuilder builder(value.getContext());
StringAttr stringValue = getStringAttribute(value, builder);
StringAttr stringValue = getStringAttribute(value, builder, escapeAttr);
for (auto* it = existingAttributes.begin(); it != existingAttributes.end();
++it) {
if (it->getName() == name) {
Expand Down Expand Up @@ -130,19 +134,20 @@ void setFuncArgFrontendAttrs(FuncOp funcOp, unsigned int index,

} // namespace

void setFrontendAttribute(Operation* op, StringRef name, Attribute value) {
void setFrontendAttribute(Operation* op, StringRef name, Attribute value,
bool escapeAttr) {
SmallVector<NamedAttribute> existingAttributes =
getExistingFrontendAttributes(getFrontendAttrs(op), "");
setFrontendAttribute(existingAttributes, name, value);
setFrontendAttribute(existingAttributes, name, value, escapeAttr);
setFrontendAttrs(op, existingAttributes);
}

void setFrontendAttribute(FuncOp funcOp, StringRef name, Attribute value,
int64_t argNum) {
int64_t argNum, bool escapeAttr) {
SmallVector<NamedAttribute> existingAttributes =
getExistingFrontendAttributes(getFuncArgFrontendAttrs(funcOp, argNum),
"");
setFrontendAttribute(existingAttributes, name, value);
setFrontendAttribute(existingAttributes, name, value, escapeAttr);
setFuncArgFrontendAttrs(funcOp, argNum, existingAttributes);
}

Expand Down
5 changes: 3 additions & 2 deletions xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ mlir::DictionaryAttr getFuncArgFrontendAttrs(mlir::func::FuncOp funcOp,
// `name` already exists, it will be overwritten. Note that `value` will be
// turned into a `StringAttr`.
void setFrontendAttribute(mlir::Operation* op, mlir::StringRef name,
mlir::Attribute value);
mlir::Attribute value, bool escapeAttr = true);

// Adds `name` into the argument at `argNum`'s frontend attributes of `funcOp`
// with value `value`. If `name` already exists, it will be overwritten. Note
// that `value` will be turned into a `StringAttr`.
void setFrontendAttribute(mlir::func::FuncOp funcOp, mlir::StringRef name,
mlir::Attribute value, int64_t argNum);
mlir::Attribute value, int64_t argNum,
bool escapeAttr = true);

// Remove `attributeName` from the frontend attributes of `op`.
void removeFrontendAttribute(mlir::Operation* op,
Expand Down

0 comments on commit fb8e7d5

Please sign in to comment.