mirror of
https://github.com/intel/llvm.git
synced 2026-01-31 07:27:33 +08:00
[mlir][transform] Allow arbitrary indices to be scalable
This change lifts the limitation that only the trailing dimensions/sizes
in dynamic index lists can be scalable. It allows us to extend
`MaskedVectorizeOp` and `TileOp` from the Transform dialect so that the
following is allowed:
%1, %loops:3 = transform.structured.tile %0 [4, [4], [4]]
This is also a follow up for https://reviews.llvm.org/D153372
that will enable the following (middle vector dimension is scalable):
transform.structured.masked_vectorize %0 vector_sizes [2, [4], 8]
To facilate this change, the hooks for parsing and printing dynamic
index lists are updated accordingly (`printDynamicIndexList` and
`parseDynamicIndexList`, respectively). `MaskedVectorizeOp` and `TileOp`
are updated to include an array of attribute of bools that captures
whether the corresponding vector dimension/tile size, respectively, are
scalable or not.
NOTE 1: I am re-landing this after the initial version was reverted. To
fix the regression and in addition to the original patch, this revision
updates the Python bindings for the transform dialect
NOTE 2: This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
This relands 048764f23a with fixes.
Differential Revision: https://reviews.llvm.org/D154336
This commit is contained in:
@@ -1690,7 +1690,7 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
|
||||
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
|
||||
DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
|
||||
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
|
||||
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
|
||||
Variadic<TransformHandleTypeInterface>:$loops);
|
||||
let builders = [
|
||||
@@ -2012,9 +2012,10 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
|
||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||
Variadic<TransformHandleTypeInterface>:$vector_sizes,
|
||||
UnitAttr:$vectorize_nd_extract,
|
||||
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
|
||||
$scalable_sizes,
|
||||
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
|
||||
$static_vector_sizes,
|
||||
DefaultValuedOptionalAttr<BoolAttr, "false">:$last_vector_size_scalable);
|
||||
$static_vector_sizes);
|
||||
|
||||
let results = (outs);
|
||||
let assemblyFormat = [{
|
||||
@@ -2022,7 +2023,7 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
|
||||
`vector_sizes` custom<DynamicIndexList>($vector_sizes,
|
||||
$static_vector_sizes,
|
||||
type($vector_sizes),
|
||||
$last_vector_size_scalable)
|
||||
$scalable_sizes)
|
||||
attr-dict
|
||||
`:` type($target)
|
||||
}];
|
||||
|
||||
@@ -52,13 +52,15 @@ namespace mlir {
|
||||
/// integer attributes in a list. E.g.
|
||||
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
|
||||
///
|
||||
/// If `isTrailingIdxScalable` is true, then wrap the trailing index with
|
||||
/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
|
||||
/// used for scalable tile or vector sizes.
|
||||
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
|
||||
/// This notation is similar to how scalable dims are marked when defining
|
||||
/// Vectors. For each value in `integers`, the corresponding `bool` in
|
||||
/// `scalables` encodes whether it's a scalable index. If `scalableVals` is
|
||||
/// empty then assume that all indices are non-scalable.
|
||||
void printDynamicIndexList(
|
||||
OpAsmPrinter &printer, Operation *op, OperandRange values,
|
||||
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
|
||||
BoolAttr isTrailingIdxScalable = {},
|
||||
ArrayRef<bool> scalables = {},
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
|
||||
|
||||
/// Parser hook for custom directive in assemblyFormat.
|
||||
@@ -78,41 +80,43 @@ void printDynamicIndexList(
|
||||
/// `kDynamic`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
///
|
||||
/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
|
||||
/// scalable. This notation is similar to how scalable dims are marked when
|
||||
/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are
|
||||
/// not allowed/expected. When it's not null, this hook will set the
|
||||
/// corresponding value to:
|
||||
/// * true if the trailing idx is scalable,
|
||||
/// * false otherwise.
|
||||
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
|
||||
/// This notation is similar to how scalable dims are marked when defining
|
||||
/// Vectors. For each value in `integers`, the corresponding `bool` in
|
||||
/// `scalableVals` encodes whether it's a scalable index.
|
||||
ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
|
||||
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
|
||||
SmallVectorImpl<Type> *valueTypes = nullptr,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
|
||||
inline ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
||||
return parseDynamicIndexList(parser, values, integers,
|
||||
/*isTrailingIdxScalable=*/nullptr, &valueTypes,
|
||||
delimiter);
|
||||
DenseBoolArrayAttr scalableVals = {};
|
||||
return parseDynamicIndexList(parser, values, integers, scalableVals,
|
||||
valueTypes, delimiter);
|
||||
}
|
||||
inline ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
||||
BoolAttr &isTrailingIdxScalable,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
||||
DenseBoolArrayAttr scalableVals = {};
|
||||
return parseDynamicIndexList(parser, values, integers, scalableVals,
|
||||
&valueTypes, delimiter);
|
||||
}
|
||||
inline ParseResult parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
|
||||
DenseBoolArrayAttr &scalableVals,
|
||||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
|
||||
|
||||
bool scalable = false;
|
||||
auto res = parseDynamicIndexList(parser, values, integers, &scalable,
|
||||
&valueTypes, delimiter);
|
||||
auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
|
||||
isTrailingIdxScalable = scalableAttr;
|
||||
return res;
|
||||
return parseDynamicIndexList(parser, values, integers, scalableVals,
|
||||
&valueTypes, delimiter);
|
||||
}
|
||||
|
||||
/// Verify that a the `values` has as many elements as the number of entries in
|
||||
|
||||
@@ -2451,7 +2451,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
|
||||
SmallVector<Operation *> tiled;
|
||||
SmallVector<SmallVector<Operation *, 4>, 4> loops;
|
||||
loops.resize(getLoops().size());
|
||||
bool scalable = getLastTileSizeScalable();
|
||||
auto scalableSizes = getScalableSizes();
|
||||
for (auto [i, op] : llvm::enumerate(targets)) {
|
||||
auto tilingInterface = dyn_cast<TilingInterface>(op);
|
||||
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
|
||||
@@ -2470,12 +2470,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
|
||||
SmallVector<Value, 4> sizes;
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
unsigned trailingIdx = getMixedSizes().size() - 1;
|
||||
|
||||
for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||
// Only the trailing tile size is allowed to be scalable atm.
|
||||
if (scalable && (ofrIdx == trailingIdx)) {
|
||||
if (scalableSizes[ofrIdx]) {
|
||||
auto val = b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt());
|
||||
Value vscale =
|
||||
@@ -2577,9 +2575,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
DenseI64ArrayAttr staticSizes;
|
||||
FunctionType functionalType;
|
||||
llvm::SMLoc operandLoc;
|
||||
bool scalable = false;
|
||||
DenseBoolArrayAttr scalableVals;
|
||||
|
||||
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
|
||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
|
||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
|
||||
parseOptionalInterchange(parser, result) ||
|
||||
parser.parseColonType(functionalType))
|
||||
return ParseResult::failure();
|
||||
@@ -2602,9 +2601,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
|
||||
result.addAttribute(getLastTileSizeScalableAttrName(result.name),
|
||||
scalableAttr);
|
||||
result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
|
||||
|
||||
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
|
||||
result.addTypes(functionalType.getResults());
|
||||
@@ -2614,7 +2611,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
void TileOp::print(OpAsmPrinter &p) {
|
||||
p << ' ' << getTarget();
|
||||
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
|
||||
/*valueTypes=*/{}, getLastTileSizeScalableAttr(),
|
||||
/*valueTypes=*/{}, getScalableSizesAttr(),
|
||||
OpAsmParser::Delimiter::Square);
|
||||
printOptionalInterchange(p, getInterchange());
|
||||
p << " : ";
|
||||
@@ -3161,15 +3158,14 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
|
||||
}
|
||||
|
||||
// TODO: Check that the correct number of vectorSizes was provided.
|
||||
SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
|
||||
scalableVecDims.back() = getLastVectorSizeScalable();
|
||||
for (Operation *target : targets) {
|
||||
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
|
||||
return mlir::emitSilenceableFailure(target->getLoc())
|
||||
<< "Unsupported Op, cannot vectorize";
|
||||
}
|
||||
|
||||
if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims,
|
||||
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
|
||||
getScalableSizes(),
|
||||
getVectorizeNdExtract()))) {
|
||||
return mlir::emitSilenceableFailure(target->getLoc())
|
||||
<< "Attempted to vectorize, but failed";
|
||||
|
||||
@@ -1254,20 +1254,20 @@ void ForallOp::print(OpAsmPrinter &p) {
|
||||
if (isNormalized()) {
|
||||
p << ") in ";
|
||||
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
|
||||
/*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
|
||||
/*valueTypes=*/{}, /*scalables=*/{},
|
||||
OpAsmParser::Delimiter::Paren);
|
||||
} else {
|
||||
p << ") = ";
|
||||
printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
|
||||
/*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
|
||||
/*valueTypes=*/{}, /*scalables=*/{},
|
||||
OpAsmParser::Delimiter::Paren);
|
||||
p << " to ";
|
||||
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
|
||||
/*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
|
||||
/*valueTypes=*/{}, /*scalables=*/{},
|
||||
OpAsmParser::Delimiter::Paren);
|
||||
p << " step ";
|
||||
printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
|
||||
/*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
|
||||
/*valueTypes=*/{}, /*scalable=*/{},
|
||||
OpAsmParser::Delimiter::Paren);
|
||||
}
|
||||
printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
|
||||
@@ -1299,9 +1299,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
dynamicSteps;
|
||||
if (succeeded(parser.parseOptionalKeyword("in"))) {
|
||||
// Parse upper bounds.
|
||||
if (parseDynamicIndexList(
|
||||
parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
@@ -1311,26 +1311,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
} else {
|
||||
// Parse lower bounds.
|
||||
if (parser.parseEqual() ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(parser, dynamicLbs, staticLbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
|
||||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse upper bounds.
|
||||
if (parser.parseKeyword("to") ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(parser, dynamicUbs, staticUbs,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse step values.
|
||||
if (parser.parseKeyword("step") ||
|
||||
parseDynamicIndexList(
|
||||
parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
|
||||
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
|
||||
parseDynamicIndexList(parser, dynamicSteps, staticSteps,
|
||||
/*valueTypes=*/nullptr,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(dynamicSteps, indexType, result.operands))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -42,6 +42,5 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
|
||||
return success();
|
||||
}
|
||||
|
||||
return parseDynamicIndexList(parser, values, integers,
|
||||
/*isTrailingIdxScalable=*/nullptr, &valueTypes);
|
||||
return parseDynamicIndexList(parser, values, integers, &valueTypes);
|
||||
}
|
||||
|
||||
@@ -102,8 +102,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
|
||||
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
OperandRange values,
|
||||
ArrayRef<int64_t> integers,
|
||||
TypeRange valueTypes,
|
||||
BoolAttr isTrailingIdxScalable,
|
||||
TypeRange valueTypes, ArrayRef<bool> scalables,
|
||||
AsmParser::Delimiter delimiter) {
|
||||
char leftDelimiter = getLeftDelimiter(delimiter);
|
||||
char rightDelimiter = getRightDelimiter(delimiter);
|
||||
@@ -113,33 +112,24 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t trailingScalableInteger;
|
||||
if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
|
||||
// ATM only the trailing idx can be scalable
|
||||
trailingScalableInteger = integers.back();
|
||||
integers = integers.drop_back();
|
||||
}
|
||||
|
||||
unsigned idx = 0;
|
||||
unsigned dynamicValIdx = 0;
|
||||
unsigned scalableIndexIdx = 0;
|
||||
llvm::interleaveComma(integers, printer, [&](int64_t integer) {
|
||||
if (not scalables.empty() && scalables[scalableIndexIdx])
|
||||
printer << "[";
|
||||
if (ShapedType::isDynamic(integer)) {
|
||||
printer << values[idx];
|
||||
printer << values[dynamicValIdx];
|
||||
if (!valueTypes.empty())
|
||||
printer << " : " << valueTypes[idx];
|
||||
++idx;
|
||||
printer << " : " << valueTypes[dynamicValIdx];
|
||||
++dynamicValIdx;
|
||||
} else {
|
||||
printer << integer;
|
||||
}
|
||||
});
|
||||
if (!scalables.empty() && scalables[scalableIndexIdx])
|
||||
printer << "]";
|
||||
|
||||
// Print the trailing scalable index
|
||||
if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
|
||||
if (!integers.empty())
|
||||
printer << ", ";
|
||||
printer << "[";
|
||||
printer << trailingScalableInteger;
|
||||
printer << "]";
|
||||
}
|
||||
scalableIndexIdx++;
|
||||
});
|
||||
|
||||
printer << rightDelimiter;
|
||||
}
|
||||
@@ -147,25 +137,17 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
ParseResult mlir::parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
|
||||
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
|
||||
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
|
||||
|
||||
SmallVector<int64_t, 4> integerVals;
|
||||
bool foundScalable = false;
|
||||
SmallVector<bool, 4> scalableVals;
|
||||
auto parseIntegerOrValue = [&]() {
|
||||
OpAsmParser::UnresolvedOperand operand;
|
||||
auto res = parser.parseOptionalOperand(operand);
|
||||
|
||||
// If `foundScalable` has already been set to `true` then a non-trailing
|
||||
// index was identified as scalable.
|
||||
if (foundScalable) {
|
||||
parser.emitError(parser.getNameLoc())
|
||||
<< "non-trailing index cannot be scalable";
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
|
||||
foundScalable = true;
|
||||
// When encountering `[`, assume that this is a scalable index.
|
||||
scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
|
||||
|
||||
if (res.has_value() && succeeded(res.value())) {
|
||||
values.push_back(operand);
|
||||
@@ -178,7 +160,10 @@ ParseResult mlir::parseDynamicIndexList(
|
||||
return failure();
|
||||
integerVals.push_back(integer);
|
||||
}
|
||||
if (foundScalable && parser.parseOptionalRSquare().failed())
|
||||
|
||||
// If this is assumed to be a scalable index, verify that there's a closing
|
||||
// `]`.
|
||||
if (scalableVals.back() && parser.parseOptionalRSquare().failed())
|
||||
return failure();
|
||||
return success();
|
||||
};
|
||||
@@ -187,8 +172,7 @@ ParseResult mlir::parseDynamicIndexList(
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< "expected SSA value or integer";
|
||||
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
|
||||
if (isTrailingIdxScalable)
|
||||
*isTrailingIdxScalable = foundScalable;
|
||||
scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from typing import List, Optional, Sequence, Union, overload
|
||||
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
|
||||
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
|
||||
def _get_int_int_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
|
||||
@@ -226,6 +229,7 @@ class TileOp:
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@@ -240,6 +244,7 @@ class TileOp:
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@@ -254,6 +259,7 @@ class TileOp:
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@@ -261,6 +267,8 @@ class TileOp:
|
||||
interchange = []
|
||||
if sizes is None:
|
||||
sizes = []
|
||||
if scalable_sizes is None:
|
||||
scalable_sizes = []
|
||||
|
||||
static_sizes = []
|
||||
dynamic_sizes = []
|
||||
@@ -298,6 +306,7 @@ class TileOp:
|
||||
dynamic_sizes=dynamic_sizes,
|
||||
static_sizes=sizes_attr,
|
||||
interchange=interchange,
|
||||
scalable_sizes=scalable_sizes,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -105,6 +105,10 @@ def _f64ArrayAttr(x, context):
|
||||
def _denseI64ArrayAttr(x, context):
|
||||
return DenseI64ArrayAttr.get(x, context=context)
|
||||
|
||||
@register_attribute_builder("DenseBoolArrayAttr")
|
||||
def _denseBoolArrayAttr(x, context):
|
||||
return DenseBoolArrayAttr.get(x, context=context)
|
||||
|
||||
|
||||
@register_attribute_builder("TypeAttr")
|
||||
def _typeAttr(x, context):
|
||||
|
||||
@@ -220,25 +220,3 @@ transform.sequence failures(propagate) {
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// TODO: Add support for for specyfying more than one scalable tile size
|
||||
|
||||
func.func @scalable_and_fixed_length_tile(
|
||||
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32> {
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
outs(%arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32>
|
||||
|
||||
return %0 : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{non-trailing index cannot be scalable}}
|
||||
// expected-error @below {{expected SSA value or integer}}
|
||||
%1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
@@ -105,3 +105,11 @@ transform.sequence failures(propagate) {
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// CHECK: transform.sequence
|
||||
// CHECK: transform.structured.tile %0{{\[}}[2], 4, 8]
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.structured.tile %0 [[2], 4, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user