mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[MLIR][DLTI] Enable types as keys in DLTI-query utils (#105995)
Enable support for query functions - including transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.
This commit is contained in:
@@ -26,7 +26,7 @@ namespace mlir {
|
||||
namespace dlti {
|
||||
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
|
||||
/// query interface-implementing attrs, starting from attr obtained from `op`.
|
||||
FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
|
||||
FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
|
||||
bool emitError = false);
|
||||
} // namespace dlti
|
||||
} // namespace mlir
|
||||
|
||||
@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
|
||||
|
||||
A lookup is performed for the given `keys` at `target` op - or its closest
|
||||
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
|
||||
returns an attribute for a key. If more than one key is provided, the lookup
|
||||
continues recursively, now on the returned attributes, with the condition
|
||||
that these implement the above interface. For example if the payload IR is
|
||||
returns an attribute for a key. Each key should be either a (quoted) string
|
||||
or a type. If more than one key is provided, the lookup continues
|
||||
recursively, now on the returned attributes, with the condition that these
|
||||
implement the above interface. For example if the payload IR is
|
||||
|
||||
```
|
||||
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
|
||||
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||
StrArrayAttr:$keys);
|
||||
ArrayAttr:$keys);
|
||||
let results = (outs TransformParamTypeInterface:$associated_attr);
|
||||
let assemblyFormat =
|
||||
"$keys `at` $target attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
@@ -424,8 +424,16 @@ getClosestQueryable(Operation *op) {
|
||||
return std::pair(queryable, op);
|
||||
}
|
||||
|
||||
FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
|
||||
bool emitError) {
|
||||
FailureOr<Attribute>
|
||||
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
|
||||
if (keys.empty()) {
|
||||
if (emitError) {
|
||||
auto diag = op->emitError() << "target op of failed DLTI query";
|
||||
diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto [queryable, queryOp] = getClosestQueryable(op);
|
||||
Operation *reportOp = (queryOp ? queryOp : op);
|
||||
|
||||
@@ -438,6 +446,15 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
|
||||
std::string buf;
|
||||
llvm::TypeSwitch<DataLayoutEntryKey>(key)
|
||||
.Case<StringAttr, Type>( // The only two kinds of key we know of.
|
||||
[&](auto key) { llvm::raw_string_ostream(buf) << key; })
|
||||
.Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
|
||||
return buf;
|
||||
};
|
||||
|
||||
Attribute currentAttr = queryable;
|
||||
for (auto &&[idx, key] : llvm::enumerate(keys)) {
|
||||
if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
|
||||
@@ -446,17 +463,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
|
||||
if (emitError) {
|
||||
auto diag = op->emitError() << "target op of failed DLTI query";
|
||||
diag.attachNote(reportOp->getLoc())
|
||||
<< "key " << key << " has no DLTI-mapping per attr: " << map;
|
||||
<< "key " << keyToStr(key)
|
||||
<< " has no DLTI-mapping per attr: " << map;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
currentAttr = *maybeAttr;
|
||||
} else {
|
||||
if (emitError) {
|
||||
std::string commaSeparatedKeys;
|
||||
llvm::interleave(
|
||||
keys.take_front(idx), // All prior keys.
|
||||
[&](auto key) { commaSeparatedKeys += keyToStr(key); },
|
||||
[&]() { commaSeparatedKeys += ","; });
|
||||
|
||||
auto diag = op->emitError() << "target op of failed DLTI query";
|
||||
diag.attachNote(reportOp->getLoc())
|
||||
<< "got non-DLTI-queryable attribute upon looking up keys ["
|
||||
<< keys.take_front(idx) << "] at op";
|
||||
<< commaSeparatedKeys << "] at op";
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
|
||||
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
|
||||
transform::TransformRewriter &rewriter, Operation *target,
|
||||
transform::ApplyToEachResultList &results, TransformState &state) {
|
||||
auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
|
||||
SmallVector<DataLayoutEntryKey> keys;
|
||||
for (Attribute key : getKeys()) {
|
||||
if (auto strKey = dyn_cast<StringAttr>(key))
|
||||
keys.push_back(strKey);
|
||||
else if (auto typeKey = dyn_cast<TypeAttr>(key))
|
||||
keys.push_back(typeKey.getValue());
|
||||
else
|
||||
return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
|
||||
"only StringAttr and TypeAttr are allowed");
|
||||
}
|
||||
|
||||
FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
|
||||
|
||||
|
||||
@@ -33,6 +33,14 @@
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{repeated layout entry key: 'i32'}}
|
||||
"test.unknown_op"() { test.unknown_attr = #dlti.map<
|
||||
#dlti.dl_entry<i32, 42>,
|
||||
#dlti.dl_entry<i32, 42>
|
||||
>} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{repeated layout entry key: 'i32'}}
|
||||
"test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
|
||||
#dlti.dl_entry<i32, 42>,
|
||||
|
||||
@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{i32 present in set : unit}}
|
||||
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, unit>>} {
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
|
||||
%param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
|
||||
transform.debug.emit_param_as_remark %param, "i32 present in set :" at %module : !transform.any_param, !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{associated attr 32 : i32}}
|
||||
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
|
||||
%param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
|
||||
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{width in bits of i32 = 32 : i64}}
|
||||
// expected-remark @below {{width in bits of f64 = 64 : i64}}
|
||||
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
|
||||
%i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
|
||||
%f64bits = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
|
||||
transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
|
||||
transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{associated attr 42 : i32}}
|
||||
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
|
||||
func.func private @f()
|
||||
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
|
||||
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
|
||||
// expected-error @below {{target op of failed DLTI query}}
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{'transform.dlti.query' op failed to apply}}
|
||||
%param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @below {{target op of failed DLTI query}}
|
||||
// expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
|
||||
@@ -353,6 +424,55 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
|
||||
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
|
||||
// expected-error @below {{target op of failed DLTI query}}
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{'transform.dlti.query' op failed to apply}}
|
||||
%param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{'transform.dlti.query' keys of wrong type: only StringAttr and TypeAttr are allowed}}
|
||||
%param = transform.dlti.query [1] at %funcs : (!transform.any_op) -> !transform.param<i64>
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"test.id", 42 : i32>>} {
|
||||
// expected-error @below {{target op of failed DLTI query}}
|
||||
// expected-note @below {{no keys provided to attempt query with}}
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
|
||||
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{'transform.dlti.query' op failed to apply}}
|
||||
%param = transform.dlti.query [] at %func : (!transform.any_op) -> !transform.any_param
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
|
||||
func.func private @f()
|
||||
}
|
||||
|
||||
@@ -206,3 +206,18 @@ module attributes {
|
||||
"GPU": #dlti.target_device_spec<
|
||||
#dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
|
||||
>} {}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: "test.op_with_dlti_map"() ({
|
||||
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
|
||||
"test.op_with_dlti_map"() ({
|
||||
}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: "test.op_with_dlti_map"() ({
|
||||
// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
|
||||
"test.op_with_dlti_map"() ({
|
||||
}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()
|
||||
|
||||
Reference in New Issue
Block a user