[MLIR] Adopt DenseI64ArrayAttr in tensor, memref and linalg transform

This commit is a first step toward removing inconsistencies between dynamic
and static attributes (i64 v. index) by dropping `I64ArrayAttr` and
using `DenseI64ArrayAttr` in Tensor, Memref and Linalg Transform ops.
In Linalg Transform ops only `TileToScfForOp` and `TileOp` have been updated.

See related discussion: https://discourse.llvm.org/t/rfc-inconsistency-between-dynamic-and-static-attributes-i64-v-index/66612/1

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138567
This commit is contained in:
Lorenzo Chelini
2022-11-22 12:41:44 +01:00
parent 36f61d14fb
commit a9733b8a5e
18 changed files with 253 additions and 236 deletions

View File

@@ -839,8 +839,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$dynamic_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);
@@ -917,8 +917,8 @@ def TileToForeachThreadOp :
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$num_threads,
Variadic<PDL_Operation>:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
@@ -1009,8 +1009,8 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$dynamic_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);

View File

@@ -1260,9 +1260,9 @@ def MemRef_ReinterpretCastOp
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides);
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
@@ -1476,7 +1476,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
or copies.
A reassociation is defined as a grouping of dimensions and is represented
with an array of I64ArrayAttr attributes.
with an array of DenseI64ArrayAttr attributes.
Example:
@@ -1563,7 +1563,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
type.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
represented with an array of DenseI64ArrayAttr attribute.
Note: Only the dimensions within a reassociation group must be contiguous.
The remaining dimensions may be non-contiguous.
@@ -1855,9 +1855,9 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides);
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{

View File

@@ -326,9 +326,9 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
@@ -807,9 +807,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
@@ -1013,7 +1013,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
represented with an array of DenseI64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
result tensor with the higher rank to obtain the operand tensor with the
@@ -1065,7 +1065,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
represented with an array of I64ArrayAttr attribute.
represented with an array of DenseI64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
operand tensor with the higher rank to obtain the result tensor with the
@@ -1206,8 +1206,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
AnyTensor:$source,
Variadic<Index>:$low,
Variadic<Index>:$high,
I64ArrayAttr:$static_low,
I64ArrayAttr:$static_high,
DenseI64ArrayAttr:$static_low,
DenseI64ArrayAttr:$static_high,
UnitAttr:$nofold);
let regions = (region SizedRegion<1>:$region);
@@ -1254,16 +1254,17 @@ def Tensor_PadOp : Tensor_Op<"pad", [
// Return a vector of all the static or dynamic values (low/high padding) of
// the op.
inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayRef<int64_t> staticAttrs,
ValueRange values) {
Builder builder(*this);
SmallVector<OpFoldResult> res;
unsigned numDynamic = 0;
unsigned count = staticAttrs.size();
for (unsigned idx = 0; idx < count; ++idx) {
if (ShapedType::isDynamic(staticAttrs[idx].cast<IntegerAttr>().getInt()))
if (ShapedType::isDynamic(staticAttrs[idx]))
res.push_back(values[numDynamic++]);
else
res.push_back(staticAttrs[idx]);
res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
}
return res;
}
@@ -1400,9 +1401,9 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I64ArrayAttr:$static_offsets,
I64ArrayAttr:$static_sizes,
I64ArrayAttr:$static_strides
DenseI64ArrayAttr:$static_offsets,
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let assemblyFormat = [{
$source `into` $dest ``
@@ -1748,7 +1749,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
I64ArrayAttr:$static_inner_tiles);
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
@@ -1803,7 +1804,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
I64ArrayAttr:$static_inner_tiles);
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source

View File

@@ -87,6 +87,18 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec);
/// Return a vector of OpFoldResults with the same size a staticValues, but all
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

View File

@@ -21,18 +21,6 @@
namespace mlir {
/// Return a vector of OpFoldResults with the same size a staticValues, but all
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
ValueRange dynamicValues);
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);
class OffsetSizeAndStrideOpInterface;
namespace detail {
@@ -61,7 +49,7 @@ namespace mlir {
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
/// `[%arg0, 7, 42, %arg42]`.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers);
OperandRange values, ArrayRef<int64_t> integers);
/// Pasrer hook for custom directive in assemblyFormat.
///
@@ -79,13 +67,14 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers);
DenseI64ArrayAttr &integers);
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
unsigned expectedNumElements,
ArrayAttr attr, ValueRange values);
ArrayRef<int64_t> attr,
ValueRange values);
} // namespace mlir

View File

@@ -124,7 +124,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the static offset attributes.
}],
/*retTy=*/"::mlir::ArrayAttr",
/*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_offsets",
/*args=*/(ins),
/*methodBody=*/"",
@@ -136,7 +136,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the static size attributes.
}],
/*retTy=*/"::mlir::ArrayAttr",
/*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_sizes",
/*args=*/(ins),
/*methodBody=*/"",
@@ -148,7 +148,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the dynamic stride attributes.
}],
/*retTy=*/"::mlir::ArrayAttr",
/*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_strides",
/*args=*/(ins),
/*methodBody=*/"",
@@ -165,8 +165,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
Builder b($_op->getContext());
return ::mlir::getMixedValues($_op.getStaticOffsets(),
$_op.getOffsets());
$_op.getOffsets(), b);
}]
>,
InterfaceMethod<
@@ -178,7 +179,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes());
Builder b($_op->getContext());
return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes(), b);
}]
>,
InterfaceMethod<
@@ -190,8 +192,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
Builder b($_op->getContext());
return ::mlir::getMixedValues($_op.getStaticStrides(),
$_op.getStrides());
$_op.getStrides(), b);
}]
>,
@@ -202,9 +205,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::llvm::APInt v = *(static_offsets()
.template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return ::mlir::ShapedType::isDynamic(v.getSExtValue());
return ::mlir::ShapedType::isDynamic(static_offsets()[idx]);
}]
>,
InterfaceMethod<
@@ -214,9 +215,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::llvm::APInt v = *(static_sizes()
.template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return ::mlir::ShapedType::isDynamic(v.getSExtValue());
return ::mlir::ShapedType::isDynamic(static_sizes()[idx]);
}]
>,
InterfaceMethod<
@@ -226,9 +225,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
::llvm::APInt v = *(static_strides()
.template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return ::mlir::ShapedType::isDynamic(v.getSExtValue());
return ::mlir::ShapedType::isDynamic(static_strides()[idx]);
}]
>,
InterfaceMethod<
@@ -241,9 +238,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicOffset(idx) && "expected static offset");
::llvm::APInt v = *(static_offsets().
template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return v.getSExtValue();
return static_offsets()[idx];
}]
>,
InterfaceMethod<
@@ -256,9 +251,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicSize(idx) && "expected static size");
::llvm::APInt v = *(static_sizes().
template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return v.getSExtValue();
return static_sizes()[idx];
}]
>,
InterfaceMethod<
@@ -271,9 +264,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicStride(idx) && "expected static stride");
::llvm::APInt v = *(static_strides().
template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
return v.getSExtValue();
return static_strides()[idx];
}]
>,
@@ -289,7 +280,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicOffset(idx) && "expected dynamic offset");
auto numDynamic = getNumDynamicEntriesUpToIdx(
static_offsets().template cast<::mlir::ArrayAttr>(),
static_offsets(),
::mlir::ShapedType::isDynamic,
idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic;
@@ -307,7 +298,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicSize(idx) && "expected dynamic size");
auto numDynamic = getNumDynamicEntriesUpToIdx(
static_sizes().template cast<::mlir::ArrayAttr>(), ::mlir::ShapedType::isDynamic, idx);
static_sizes(), ::mlir::ShapedType::isDynamic, idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() +
offsets().size() + numDynamic;
}]
@@ -324,7 +315,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicStride(idx) && "expected dynamic stride");
auto numDynamic = getNumDynamicEntriesUpToIdx(
static_strides().template cast<::mlir::ArrayAttr>(),
static_strides(),
::mlir::ShapedType::isDynamic,
idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() +
@@ -333,20 +324,20 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
>,
InterfaceMethod<
/*desc=*/[{
Helper method to compute the number of dynamic entries of `attr`, up to
Helper method to compute the number of dynamic entries of `staticVals`, up to
`idx` using `isDynamic` to determine whether an entry is dynamic.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumDynamicEntriesUpToIdx",
/*args=*/(ins "::mlir::ArrayAttr":$attr,
/*args=*/(ins "::llvm::ArrayRef<int64_t>":$staticVals,
"::llvm::function_ref<bool(int64_t)>":$isDynamic,
"unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return std::count_if(
attr.getValue().begin(), attr.getValue().begin() + idx,
[&](::mlir::Attribute attr) {
return isDynamic(attr.cast<::mlir::IntegerAttr>().getInt());
staticVals.begin(), staticVals.begin() + idx,
[&](int64_t val) {
return isDynamic(val);
});
}]
>,

View File

@@ -1705,10 +1705,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
auto viewMemRefType = subViewOp.getType();
auto inferredType =
memref::SubViewOp::inferResultType(
subViewOp.getSourceType(),
extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
subViewOp.getSourceType(), subViewOp.getStaticOffsets(),
subViewOp.getStaticSizes(), subViewOp.getStaticStrides())
.cast<MemRefType>();
auto targetElementTy =
typeConverter->convertType(viewMemRefType.getElementType());

View File

@@ -30,8 +30,8 @@ public:
PatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
Value input = sliceOp.getInput();
SmallVector<int64_t> strides, sizes;
auto starts = sliceOp.getStart();
SmallVector<int64_t> strides, sizes, starts;
starts = extractFromI64ArrayAttr(sliceOp.getStart());
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
SmallVector<Value> dynSizes;
@@ -44,15 +44,15 @@ public:
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
auto offset = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
loc, rewriter.getIndexAttr(starts[index]));
dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
}
auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes),
rewriter.getI64ArrayAttr(strides));
ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getDenseI64ArrayAttr(strides));
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
return success();

View File

@@ -40,16 +40,6 @@ static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
return result;
}
/// Extracts a vector of int64_t from an array attribute. Asserts if the
/// attribute contains values other than integers.
static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
SmallVector<int64_t> result;
result.reserve(attr.size());
for (APInt value : attr.getAsValueRange<IntegerAttr>())
result.push_back(value.getSExtValue());
return result;
}
namespace {
/// A simple pattern rewriter that implements no special logic.
class SimpleRewriter : public PatternRewriter {
@@ -1205,7 +1195,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
DiagnosedSilenceableFailure
transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
ArrayRef<int64_t> tileSizes = getStaticSizes();
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1270,7 +1260,7 @@ transform::TileOp::apply(TransformResults &transformResults,
});
}
tilingOptions.setInterchange(extractI64Array(getInterchange()));
tilingOptions.setInterchange(getInterchange());
SimpleRewriter rewriter(linalgOp.getContext());
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
@@ -1298,7 +1288,7 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
ArrayRef<int64_t> tileSizes = getStaticSizes();
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
@@ -1313,22 +1303,51 @@ SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
return results;
}
// We want to parse `DenseI64ArrayAttr` using the short form without the
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
ParseResult parseOptionalInterchange(OpAsmParser &parser,
OperationState &result) {
if (succeeded(parser.parseOptionalLBrace())) {
if (failed(parser.parseKeyword("interchange")))
return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
if (failed(parser.parseEqual()))
return parser.emitError(parser.getNameLoc()) << "expect `=`";
result.addAttribute("interchange",
DenseI64ArrayAttr::parse(parser, Type{}));
if (failed(parser.parseRBrace()))
return parser.emitError(parser.getNameLoc()) << "expect `}`";
}
return success();
}
void printOptionalInterchange(OpAsmPrinter &p,
ArrayRef<int64_t> interchangeVals) {
if (!interchangeVals.empty()) {
p << " {interchange = [";
llvm::interleaveComma(interchangeVals, p,
[&](int64_t integer) { p << integer; });
p << "]}";
}
}
ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
ArrayAttr staticSizes;
DenseI64ArrayAttr staticSizes;
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
return ParseResult::failure();
// Parse optional interchange.
if (failed(parseOptionalInterchange(parser, result)))
return ParseResult::failure();
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}
@@ -1336,7 +1355,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
printOptionalInterchange(p, getInterchange());
}
void transform::TileOp::getEffects(
@@ -1379,13 +1398,13 @@ void transform::TileToForeachThreadOp::build(
// bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
build(builder, result,
/*resultTypes=*/TypeRange{operationType, operationType},
/*target=*/target,
/*num_threads=*/ValueRange{},
/*tile_sizes=*/dynamicTileSizes,
/*static_num_threads=*/builder.getI64ArrayAttr({}),
/*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
/*static_tile_sizes=*/staticTileSizesAttr,
/*mapping=*/mapping);
}
@@ -1414,14 +1433,14 @@ void transform::TileToForeachThreadOp::build(
// bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
build(builder, result,
/*resultTypes=*/TypeRange{operationType, operationType},
/*target=*/target,
/*num_threads=*/dynamicNumThreads,
/*tile_sizes=*/ValueRange{},
/*static_num_threads=*/staticNumThreadsAttr,
/*static_tile_sizes=*/builder.getI64ArrayAttr({}),
/*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
/*mapping=*/mapping);
}
@@ -1547,11 +1566,13 @@ void transform::TileToForeachThreadOp::getEffects(
}
SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
return getMixedValues(getStaticNumThreads(), getNumThreads());
Builder b(getContext());
return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
}
SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
return getMixedValues(getStaticTileSizes(), getTileSizes());
Builder b(getContext());
return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
}
LogicalResult TileToForeachThreadOp::verify() {
@@ -1567,7 +1588,7 @@ LogicalResult TileToForeachThreadOp::verify() {
DiagnosedSilenceableFailure
transform::TileToScfForOp::apply(TransformResults &transformResults,
TransformState &state) {
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
ArrayRef<int64_t> tileSizes = getStaticSizes();
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1632,7 +1653,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
});
}
tilingOptions.setInterchange(extractI64Array(getInterchange()));
tilingOptions.setInterchange(getInterchange());
SimpleRewriter rewriter(tilingInterfaceOp.getContext());
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
@@ -1655,7 +1676,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
ArrayRef<int64_t> tileSizes = getStaticSizes();
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
@@ -1674,18 +1695,20 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
ArrayAttr staticSizes;
DenseI64ArrayAttr staticSizes;
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
return ParseResult::failure();
// Parse optional interchange.
if (failed(parseOptionalInterchange(parser, result)))
return ParseResult::failure();
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}
@@ -1693,7 +1716,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
void TileToScfForOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
printOptionalInterchange(p, getInterchange());
}
void transform::TileToScfForOp::getEffects(

View File

@@ -348,7 +348,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
SmallVector<AffineExpr, 4> outputExprs;
for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
padOp.getStaticLow()[i]);
}
SmallVector<AffineMap, 2> transferMaps = {

View File

@@ -1776,8 +1776,9 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -1823,8 +1824,8 @@ LogicalResult ReinterpretCastOp::verify() {
<< srcType << " and result memref type " << resultType;
// Match sizes in result memref type and in static_sizes attribute.
for (auto &en : llvm::enumerate(llvm::zip(
resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
for (auto &en :
llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) &&
@@ -1844,7 +1845,7 @@ LogicalResult ReinterpretCastOp::verify() {
<< resultType;
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) &&
!ShapedType::isDynamic(expectedOffset) &&
resultOffset != expectedOffset)
@@ -1852,8 +1853,8 @@ LogicalResult ReinterpretCastOp::verify() {
<< resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip(
resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
for (auto &en :
llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultStride) &&
@@ -2665,8 +2666,9 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
.cast<MemRefType>();
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -2831,9 +2833,7 @@ LogicalResult SubViewOp::verify() {
// Verify result type against inferred type.
auto expectedType = SubViewOp::inferResultType(
baseType, extractFromI64ArrayAttr(getStaticOffsets()),
extractFromI64ArrayAttr(getStaticSizes()),
extractFromI64ArrayAttr(getStaticStrides()));
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
subViewType, getMixedSizes());

View File

@@ -45,9 +45,8 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
builder.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = builder.create<memref::SubViewOp>(
subviewUse->getLoc(), newType.cast<MemRefType>(), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),

View File

@@ -337,8 +337,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
auto dimMask = computeRankReductionMask(
extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
extractOperand.getType().getShape());
extractOperand.getStaticSizes(), extractOperand.getType().getShape());
size_t dimIndex = 0;
for (size_t i = 0, e = sizes.size(); i < e; i++) {
if (dimMask && dimMask->count(i))
@@ -1713,8 +1712,9 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
.cast<RankedTensorType>();
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -1949,13 +1949,13 @@ public:
return failure();
// Check if there are any dynamic parts, which are not supported.
auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
auto offsets = op.getStaticOffsets();
if (llvm::is_contained(offsets, ShapedType::kDynamic))
return failure();
auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
auto sizes = op.getStaticSizes();
if (llvm::is_contained(sizes, ShapedType::kDynamic))
return failure();
auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
auto strides = op.getStaticStrides();
if (llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
@@ -2124,8 +2124,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -2153,17 +2154,14 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
ArrayAttr staticOffsets, ArrayAttr staticSizes,
ArrayAttr staticStrides,
ShapedType *expectedType = nullptr) {
static SliceVerificationResult verifyInsertSliceOp(
ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
ShapedType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
RankedTensorType expected = ExtractSliceOp::inferResultType(
dstType, extractFromI64ArrayAttr(staticOffsets),
extractFromI64ArrayAttr(staticSizes),
extractFromI64ArrayAttr(staticStrides));
dstType, staticOffsets, staticSizes, staticStrides);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2482,9 +2480,8 @@ ParseResult parseInferType(OpAsmParser &parser,
LogicalResult PadOp::verify() {
auto sourceType = getSource().getType().cast<RankedTensorType>();
auto resultType = getResult().getType().cast<RankedTensorType>();
auto expectedType = PadOp::inferResultType(
sourceType, extractFromI64ArrayAttr(getStaticLow()),
extractFromI64ArrayAttr(getStaticHigh()));
auto expectedType =
PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
@@ -2556,8 +2553,9 @@ void PadOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
auto resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
build(b, result, resultType, source, low, high,
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
@@ -2591,7 +2589,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
}
assert(resultType.isa<RankedTensorType>());
build(b, result, resultType, source, dynamicLow, dynamicHigh,
b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
@@ -2658,8 +2656,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
auto newResultType = PadOp::inferResultType(
castOp.getSource().getType().cast<RankedTensorType>(),
extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
@@ -2940,8 +2937,9 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -3086,12 +3084,12 @@ template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
Builder builder(op);
SmallVector<OpFoldResult> mixedInnerTiles;
unsigned dynamicValIndex = 0;
for (Attribute attr : op.getStaticInnerTiles()) {
auto tileAttr = attr.cast<IntegerAttr>();
if (!ShapedType::isDynamic(tileAttr.getInt()))
mixedInnerTiles.push_back(tileAttr);
for (int64_t staticTile : op.getStaticInnerTiles()) {
if (!ShapedType::isDynamic(staticTile))
mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
else
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
}

View File

@@ -137,4 +137,41 @@ SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
return getValueOrCreateConstantIndexOp(b, loc, value);
}));
}
/// Return a vector of OpFoldResults with the same size a staticValues, but all
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
unsigned count = static_cast<unsigned>(staticValues.size());
for (unsigned idx = 0; idx < count; ++idx) {
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
: OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
}
return res;
}
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
dynamicValues.push_back(it.get<Value>());
}
}
return {b.getI64ArrayAttr(staticValues), dynamicValues};
}
} // namespace mlir

View File

@@ -20,15 +20,15 @@ using namespace mlir;
LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
StringRef name,
unsigned numElements,
ArrayAttr attr,
ArrayRef<int64_t> staticVals,
ValueRange values) {
/// Check static and dynamic offsets/sizes/strides does not overflow type.
if (attr.size() != numElements)
// Check static and dynamic offsets/sizes/strides does not overflow type.
if (staticVals.size() != numElements)
return op->emitError("expected ")
<< numElements << " " << name << " values";
unsigned expectedNumDynamicEntries =
llvm::count_if(attr.getValue(), [&](Attribute attr) {
return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
llvm::count_if(staticVals, [&](int64_t staticVal) {
return ShapedType::isDynamic(staticVal);
});
if (values.size() != expectedNumDynamicEntries)
return op->emitError("expected ")
@@ -70,19 +70,19 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
}
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers) {
OperandRange values,
ArrayRef<int64_t> integers) {
printer << '[';
if (integers.empty()) {
printer << "]";
return;
}
unsigned idx = 0;
llvm::interleaveComma(integers, printer, [&](Attribute a) {
int64_t val = a.cast<IntegerAttr>().getInt();
if (ShapedType::isDynamic(val))
llvm::interleaveComma(integers, printer, [&](int64_t integer) {
if (ShapedType::isDynamic(integer))
printer << values[idx++];
else
printer << val;
printer << integer;
});
printer << ']';
}
@@ -90,28 +90,28 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers) {
DenseI64ArrayAttr &integers) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
if (succeeded(parser.parseOptionalRSquare())) {
integers = parser.getBuilder().getArrayAttr({});
integers = parser.getBuilder().getDenseI64ArrayAttr({});
return success();
}
SmallVector<int64_t, 4> attrVals;
SmallVector<int64_t, 4> integerVals;
while (true) {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
attrVals.push_back(ShapedType::kDynamic);
integerVals.push_back(ShapedType::kDynamic);
} else {
IntegerAttr attr;
if (failed(parser.parseAttribute<IntegerAttr>(attr)))
int64_t integer;
if (failed(parser.parseInteger(integer)))
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
attrVals.push_back(attr.getInt());
integerVals.push_back(integer);
}
if (succeeded(parser.parseOptionalComma()))
@@ -120,7 +120,7 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
break;
}
integers = parser.getBuilder().getI64ArrayAttr(attrVals);
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
return success();
}
@@ -144,34 +144,3 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
return false;
return true;
}
SmallVector<OpFoldResult, 4> mlir::getMixedValues(ArrayAttr staticValues,
ValueRange dynamicValues) {
SmallVector<OpFoldResult, 4> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
unsigned count = static_cast<unsigned>(staticValues.size());
for (unsigned idx = 0; idx < count; ++idx) {
APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
res.push_back(ShapedType::isDynamic(value.getSExtValue())
? OpFoldResult{dynamicValues[numDynamic++]}
: OpFoldResult{staticValues[idx]});
}
return res;
}
std::pair<ArrayAttr, SmallVector<Value>>
mlir::decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
dynamicValues.push_back(it.get<Value>());
}
}
return {b.getI64ArrayAttr(staticValues), dynamicValues};
}

View File

@@ -49,6 +49,15 @@ def _get_int_array_attr(
return ArrayAttr.get([_get_int64_attr(v) for v in values])
def _get_dense_int64_array_attr(
values: Sequence[int]) -> DenseI64ArrayAttr:
"""Creates a dense integer array from a sequence of integers.
Expects the thread-local MLIR context to have been set by the context
manager.
"""
if values is None:
return DenseI64ArrayAttr.get([])
return DenseI64ArrayAttr.get(values)
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
@@ -250,14 +259,11 @@ class TileOp:
else:
for size in sizes:
if isinstance(size, int):
static_sizes.append(IntegerAttr.get(i64_type, size))
elif isinstance(size, IntegerAttr):
static_sizes.append(size)
else:
static_sizes.append(
IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = ArrayAttr.get(static_sizes)
sizes_attr = DenseI64ArrayAttr.get(static_sizes)
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
@@ -266,14 +272,14 @@ class TileOp:
_get_op_result_or_value(target),
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
interchange=_get_int_array_attr(interchange) if interchange else None,
interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
loc=loc,
ip=ip)
def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
return []
return [IntegerAttr(element).value for element in attr]
return [element for element in attr]
class VectorizeOp:

View File

@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]}
transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
}
// CHECK-LABEL: func @permute_generic
@@ -191,8 +191,8 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange=[1, 2, 0]}
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange=[1, 0, 2]}
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]}
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]}
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
}

View File

@@ -108,7 +108,6 @@ def testSplit():
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
@run
def testTileCompact():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -120,14 +119,11 @@ def testTileCompact():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
@run
def testTileAttributes():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
attr = ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
ichange = ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
transform.YieldOp()
@@ -136,7 +132,6 @@ def testTileAttributes():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
@run
def testTileZero():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -149,7 +144,6 @@ def testTileZero():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
# CHECK: interchange = [0, 1, 2, 3]
@run
def testTileDynamic():
with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())