[MLIR][DLTI] Enable types as keys in DLTI-query utils (#105995)

Enable support for query functions - including transform.dlti.query - to
take types as keys. As the data layout specific attributes already
supported types as keys, this change enables querying such attributes in
the expected way.
This commit is contained in:
Rolf Morel
2024-08-27 13:41:15 +02:00
committed by GitHub
parent f8b446086f
commit 063e0bd52a
7 changed files with 187 additions and 10 deletions

View File

@@ -26,7 +26,7 @@ namespace mlir {
namespace dlti {
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
/// query interface-implementing attrs, starting from attr obtained from `op`.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
bool emitError = false);
} // namespace dlti
} // namespace mlir

View File

@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
A lookup is performed for the given `keys` at `target` op - or its closest
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
returns an attribute for a key. If more than one key is provided, the lookup
continues recursively, now on the returned attributes, with the condition
that these implement the above interface. For example if the payload IR is
returns an attribute for a key. Each key should be either a (quoted) string
or a type. If more than one key is provided, the lookup continues
recursively, now on the returned attributes, with the condition that these
implement the above interface. For example if the payload IR is
```
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
}];
let arguments = (ins TransformHandleTypeInterface:$target,
StrArrayAttr:$keys);
ArrayAttr:$keys);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"$keys `at` $target attr-dict `:` functional-type(operands, results)";

View File

@@ -424,8 +424,16 @@ getClosestQueryable(Operation *op) {
return std::pair(queryable, op);
}
FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
bool emitError) {
FailureOr<Attribute>
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
if (keys.empty()) {
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
}
return failure();
}
auto [queryable, queryOp] = getClosestQueryable(op);
Operation *reportOp = (queryOp ? queryOp : op);
@@ -438,6 +446,15 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
return failure();
}
auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
std::string buf;
llvm::TypeSwitch<DataLayoutEntryKey>(key)
.Case<StringAttr, Type>( // The only two kinds of key we know of.
[&](auto key) { llvm::raw_string_ostream(buf) << key; })
.Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
return buf;
};
Attribute currentAttr = queryable;
for (auto &&[idx, key] : llvm::enumerate(keys)) {
if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
@@ -446,17 +463,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
if (emitError) {
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "key " << key << " has no DLTI-mapping per attr: " << map;
<< "key " << keyToStr(key)
<< " has no DLTI-mapping per attr: " << map;
}
return failure();
}
currentAttr = *maybeAttr;
} else {
if (emitError) {
std::string commaSeparatedKeys;
llvm::interleave(
keys.take_front(idx), // All prior keys.
[&](auto key) { commaSeparatedKeys += keyToStr(key); },
[&]() { commaSeparatedKeys += ","; });
auto diag = op->emitError() << "target op of failed DLTI query";
diag.attachNote(reportOp->getLoc())
<< "got non-DLTI-queryable attribute upon looking up keys ["
<< keys.take_front(idx) << "] at op";
<< commaSeparatedKeys << "] at op";
}
return failure();
}

View File

@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
SmallVector<DataLayoutEntryKey> keys;
for (Attribute key : getKeys()) {
if (auto strKey = dyn_cast<StringAttr>(key))
keys.push_back(strKey);
else if (auto typeKey = dyn_cast<TypeAttr>(key))
keys.push_back(typeKey.getValue());
else
return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
"only StringAttr and TypeAttr are allowed");
}
FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);

View File

@@ -33,6 +33,14 @@
// -----
// expected-error@below {{repeated layout entry key: 'i32'}}
"test.unknown_op"() { test.unknown_attr = #dlti.map<
#dlti.dl_entry<i32, 42>,
#dlti.dl_entry<i32, 42>
>} : () -> ()
// -----
// expected-error@below {{repeated layout entry key: 'i32'}}
"test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
#dlti.dl_entry<i32, 42>,

View File

@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
// -----
// expected-remark @below {{i32 present in set : unit}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, unit>>} {
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "i32 present in set :" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}
// -----
// expected-remark @below {{associated attr 32 : i32}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}
// -----
// expected-remark @below {{width in bits of i32 = 32 : i64}}
// expected-remark @below {{width in bits of f64 = 64 : i64}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
%i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
%f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
transform.yield
}
}
// -----
// expected-remark @below {{associated attr 42 : i32}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
// -----
// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
// expected-error @below {{target op of failed DLTI query}}
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
module {
// expected-error @below {{target op of failed DLTI query}}
// expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
@@ -353,6 +424,55 @@ module attributes {transform.with_named_sequence} {
// -----
// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
// expected-error @below {{target op of failed DLTI query}}
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' keys of wrong type: only StringAttr and TypeAttr are allowed}}
%param = transform.dlti.query [1] at %funcs : (!transform.any_op) -> !transform.param<i64>
transform.yield
}
}
// -----
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"test.id", 42 : i32>>} {
// expected-error @below {{target op of failed DLTI query}}
// expected-note @below {{no keys provided to attempt query with}}
func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{'transform.dlti.query' op failed to apply}}
%param = transform.dlti.query [] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
func.func private @f()
}

View File

@@ -206,3 +206,18 @@ module attributes {
"GPU": #dlti.target_device_spec<
#dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
>} {}
// -----
// CHECK: "test.op_with_dlti_map"() ({
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
"test.op_with_dlti_map"() ({
}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
// -----
// CHECK: "test.op_with_dlti_map"() ({
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
"test.op_with_dlti_map"() ({
}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()