Mirror of https://github.com/intel/llvm.git
[MLIR] Adopt DenseI64ArrayAttr in tensor, memref and linalg transform

This commit is a first step toward removing inconsistencies between dynamic
and static attributes (i64 vs. index) by dropping `I64ArrayAttr` and using
`DenseI64ArrayAttr` in Tensor, MemRef and Linalg Transform ops. Among the
Linalg Transform ops, only `TileToScfForOp` and `TileOp` have been updated.

See related discussion:
https://discourse.llvm.org/t/rfc-inconsistency-between-dynamic-and-static-attributes-i64-v-index/66612/1

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138567
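To make the mechanical change easier to follow, here is a minimal, hypothetical sketch (not taken from the patch) of what the switch means for client code. The attribute name `static_sizes`, the helper function names, and the chosen headers are assumptions for illustration only.

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helpers for an op carrying a "static_sizes" attribute.
// Before: I64ArrayAttr boxes every element in an IntegerAttr, so each read
// needs a cast and the values are usually copied into a SmallVector first.
SmallVector<int64_t> readSizesOldStyle(Operation *op) {
  SmallVector<int64_t> sizes;
  for (Attribute a : op->getAttrOfType<ArrayAttr>("static_sizes"))
    sizes.push_back(a.cast<IntegerAttr>().getInt());
  return sizes;
}

// After: DenseI64ArrayAttr stores the raw i64 values and hands them back
// directly as an ArrayRef<int64_t> (the storage is uniqued in the context).
ArrayRef<int64_t> readSizesNewStyle(Operation *op) {
  return op->getAttrOfType<DenseI64ArrayAttr>("static_sizes").asArrayRef();
}

// Building the attribute loses a level of boxing as well.
Attribute buildOldStyle(Builder &b) {
  return b.getI64ArrayAttr({4, 8});      // array of boxed IntegerAttr elements
}
Attribute buildNewStyle(Builder &b) {
  return b.getDenseI64ArrayAttr({4, 8}); // prints as array<i64: 4, 8>
}
```

Note that the transform ops keep their compact `[4, 8]` / `{interchange = [0, 1]}` syntax: the ops use custom parsing and printing (`parseDynamicIndexList`, plus the new `parseOptionalInterchange` below), so the default `array<i64: ...>` form does not appear in the IR.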
@@ -839,8 +839,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",

   let arguments = (ins PDL_Operation:$target,
                    Variadic<PDL_Operation>:$dynamic_sizes,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
   let results = (outs PDL_Operation:$tiled_linalg_op,
                  Variadic<PDL_Operation>:$loops);

@@ -917,8 +917,8 @@ def TileToForeachThreadOp :
   let arguments = (ins PDL_Operation:$target,
                    Variadic<PDL_Operation>:$num_threads,
                    Variadic<PDL_Operation>:$tile_sizes,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                  PDL_Operation:$tiled_op);

@@ -1009,8 +1009,8 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",

   let arguments = (ins PDL_Operation:$target,
                    Variadic<PDL_Operation>:$dynamic_sizes,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
-                  DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
+                  DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
   let results = (outs PDL_Operation:$tiled_linalg_op,
                  Variadic<PDL_Operation>:$loops);

@@ -1260,9 +1260,9 @@ def MemRef_ReinterpretCastOp
                        Variadic<Index>:$offsets,
                        Variadic<Index>:$sizes,
                        Variadic<Index>:$strides,
-                      I64ArrayAttr:$static_offsets,
-                      I64ArrayAttr:$static_sizes,
-                      I64ArrayAttr:$static_strides);
+                      DenseI64ArrayAttr:$static_offsets,
+                      DenseI64ArrayAttr:$static_sizes,
+                      DenseI64ArrayAttr:$static_strides);
   let results = (outs AnyMemRef:$result);

   let assemblyFormat = [{
@@ -1476,7 +1476,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
    or copies.

    A reassociation is defined as a grouping of dimensions and is represented
-   with an array of I64ArrayAttr attributes.
+   with an array of DenseI64ArrayAttr attributes.

    Example:

@@ -1563,7 +1563,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
    type.

    A reassociation is defined as a continuous grouping of dimensions and is
-   represented with an array of I64ArrayAttr attribute.
+   represented with an array of DenseI64ArrayAttr attribute.

    Note: Only the dimensions within a reassociation group must be contiguous.
    The remaining dimensions may be non-contiguous.
@@ -1855,9 +1855,9 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
                        Variadic<Index>:$offsets,
                        Variadic<Index>:$sizes,
                        Variadic<Index>:$strides,
-                      I64ArrayAttr:$static_offsets,
-                      I64ArrayAttr:$static_sizes,
-                      I64ArrayAttr:$static_strides);
+                      DenseI64ArrayAttr:$static_offsets,
+                      DenseI64ArrayAttr:$static_sizes,
+                      DenseI64ArrayAttr:$static_strides);
   let results = (outs AnyMemRef:$result);

   let assemblyFormat = [{

@@ -326,9 +326,9 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
                        Variadic<Index>:$offsets,
                        Variadic<Index>:$sizes,
                        Variadic<Index>:$strides,
-                      I64ArrayAttr:$static_offsets,
-                      I64ArrayAttr:$static_sizes,
-                      I64ArrayAttr:$static_strides
+                      DenseI64ArrayAttr:$static_offsets,
+                      DenseI64ArrayAttr:$static_sizes,
+                      DenseI64ArrayAttr:$static_strides
   );
   let results = (outs AnyRankedTensor:$result);

@@ -807,9 +807,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
                        Variadic<Index>:$offsets,
                        Variadic<Index>:$sizes,
                        Variadic<Index>:$strides,
-                      I64ArrayAttr:$static_offsets,
-                      I64ArrayAttr:$static_sizes,
-                      I64ArrayAttr:$static_strides
+                      DenseI64ArrayAttr:$static_offsets,
+                      DenseI64ArrayAttr:$static_sizes,
+                      DenseI64ArrayAttr:$static_strides
   );
   let results = (outs AnyRankedTensor:$result);

@@ -1013,7 +1013,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
    rank whose sizes are a reassociation of the original `src`.

    A reassociation is defined as a continuous grouping of dimensions and is
-   represented with an array of I64ArrayAttr attribute.
+   represented with an array of DenseI64ArrayAttr attribute.

    The verification rule is that the reassociation maps are applied to the
    result tensor with the higher rank to obtain the operand tensor with the
@@ -1065,7 +1065,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
    rank whose sizes are a reassociation of the original `src`.

    A reassociation is defined as a continuous grouping of dimensions and is
-   represented with an array of I64ArrayAttr attribute.
+   represented with an array of DenseI64ArrayAttr attribute.

    The verification rule is that the reassociation maps are applied to the
    operand tensor with the higher rank to obtain the result tensor with the
@@ -1206,8 +1206,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
     AnyTensor:$source,
     Variadic<Index>:$low,
     Variadic<Index>:$high,
-    I64ArrayAttr:$static_low,
-    I64ArrayAttr:$static_high,
+    DenseI64ArrayAttr:$static_low,
+    DenseI64ArrayAttr:$static_high,
     UnitAttr:$nofold);

   let regions = (region SizedRegion<1>:$region);
@@ -1254,16 +1254,17 @@ def Tensor_PadOp : Tensor_Op<"pad", [

     // Return a vector of all the static or dynamic values (low/high padding) of
     // the op.
-    inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
+    inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayRef<int64_t> staticAttrs,
                                                      ValueRange values) {
+      Builder builder(*this);
       SmallVector<OpFoldResult> res;
       unsigned numDynamic = 0;
       unsigned count = staticAttrs.size();
       for (unsigned idx = 0; idx < count; ++idx) {
-        if (ShapedType::isDynamic(staticAttrs[idx].cast<IntegerAttr>().getInt()))
+        if (ShapedType::isDynamic(staticAttrs[idx]))
           res.push_back(values[numDynamic++]);
         else
-          res.push_back(staticAttrs[idx]);
+          res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
       }
       return res;
     }
@@ -1400,9 +1401,9 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
                        Variadic<Index>:$offsets,
                        Variadic<Index>:$sizes,
                        Variadic<Index>:$strides,
-                      I64ArrayAttr:$static_offsets,
-                      I64ArrayAttr:$static_sizes,
-                      I64ArrayAttr:$static_strides
+                      DenseI64ArrayAttr:$static_offsets,
+                      DenseI64ArrayAttr:$static_sizes,
+                      DenseI64ArrayAttr:$static_strides
   );
   let assemblyFormat = [{
     $source `into` $dest ``
@@ -1748,7 +1749,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
-                      I64ArrayAttr:$static_inner_tiles);
+                      DenseI64ArrayAttr:$static_inner_tiles);
   let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
     $source
@@ -1803,7 +1804,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
-                      I64ArrayAttr:$static_inner_tiles);
+                      DenseI64ArrayAttr:$static_inner_tiles);
   let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
     $source

@@ -87,6 +87,18 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                ArrayRef<OpFoldResult> valueOrAttrVec);

+/// Return a vector of OpFoldResults with the same size a staticValues, but all
+/// elements for which ShapedType::isDynamic is true, will be replaced by
+/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b);
+
+/// Decompose a vector of mixed static or dynamic values into the corresponding
+/// pair of arrays. This is the inverse function of `getMixedValues`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+                     const SmallVectorImpl<OpFoldResult> &mixedValues);
+
 } // namespace mlir

 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

@@ -21,18 +21,6 @@

 namespace mlir {

-/// Return a vector of OpFoldResults with the same size a staticValues, but all
-/// elements for which ShapedType::isDynamic is true, will be replaced by
-/// dynamicValues.
-SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
-                                            ValueRange dynamicValues);
-
-/// Decompose a vector of mixed static or dynamic values into the corresponding
-/// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues);
-
 class OffsetSizeAndStrideOpInterface;

 namespace detail {
@@ -61,7 +49,7 @@ namespace mlir {
 /// idiomatic printing of mixed value and integer attributes in a list. E.g.
 /// `[%arg0, 7, 42, %arg42]`.
 void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                           OperandRange values, ArrayAttr integers);
+                           OperandRange values, ArrayRef<int64_t> integers);

 /// Pasrer hook for custom directive in assemblyFormat.
 ///
@@ -79,13 +67,14 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      ArrayAttr &integers);
+                      DenseI64ArrayAttr &integers);

 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
 LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
                                              unsigned expectedNumElements,
-                                             ArrayAttr attr, ValueRange values);
+                                             ArrayRef<int64_t> attr,
+                                             ValueRange values);

 } // namespace mlir

@@ -124,7 +124,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*desc=*/[{
        Return the static offset attributes.
      }],
-      /*retTy=*/"::mlir::ArrayAttr",
+      /*retTy=*/"::llvm::ArrayRef<int64_t>",
      /*methodName=*/"static_offsets",
      /*args=*/(ins),
      /*methodBody=*/"",
@@ -136,7 +136,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*desc=*/[{
        Return the static size attributes.
      }],
-      /*retTy=*/"::mlir::ArrayAttr",
+      /*retTy=*/"::llvm::ArrayRef<int64_t>",
      /*methodName=*/"static_sizes",
      /*args=*/(ins),
      /*methodBody=*/"",
@@ -148,7 +148,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*desc=*/[{
        Return the dynamic stride attributes.
      }],
-      /*retTy=*/"::mlir::ArrayAttr",
+      /*retTy=*/"::llvm::ArrayRef<int64_t>",
      /*methodName=*/"static_strides",
      /*args=*/(ins),
      /*methodBody=*/"",
@@ -165,8 +165,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
+        Builder b($_op->getContext());
        return ::mlir::getMixedValues($_op.getStaticOffsets(),
-                                      $_op.getOffsets());
+                                      $_op.getOffsets(), b);
      }]
    >,
    InterfaceMethod<
@@ -178,7 +179,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes());
+        Builder b($_op->getContext());
+        return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes(), b);
      }]
    >,
    InterfaceMethod<
@@ -190,8 +192,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
+        Builder b($_op->getContext());
        return ::mlir::getMixedValues($_op.getStaticStrides(),
-                                      $_op.getStrides());
+                                      $_op.getStrides(), b);
      }]
    >,

@@ -202,9 +205,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins "unsigned":$idx),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        ::llvm::APInt v = *(static_offsets()
-          .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+        return ::mlir::ShapedType::isDynamic(static_offsets()[idx]);
      }]
    >,
    InterfaceMethod<
@@ -214,9 +215,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins "unsigned":$idx),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        ::llvm::APInt v = *(static_sizes()
-          .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+        return ::mlir::ShapedType::isDynamic(static_sizes()[idx]);
      }]
    >,
    InterfaceMethod<
@@ -226,9 +225,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*args=*/(ins "unsigned":$idx),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
-        ::llvm::APInt v = *(static_strides()
-          .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+        return ::mlir::ShapedType::isDynamic(static_strides()[idx]);
      }]
    >,
    InterfaceMethod<
@@ -241,9 +238,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(!$_op.isDynamicOffset(idx) && "expected static offset");
-        ::llvm::APInt v = *(static_offsets().
-          template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return v.getSExtValue();
+        return static_offsets()[idx];
      }]
    >,
    InterfaceMethod<
@@ -256,9 +251,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(!$_op.isDynamicSize(idx) && "expected static size");
-        ::llvm::APInt v = *(static_sizes().
-          template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return v.getSExtValue();
+        return static_sizes()[idx];
      }]
    >,
    InterfaceMethod<
@@ -271,9 +264,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(!$_op.isDynamicStride(idx) && "expected static stride");
-        ::llvm::APInt v = *(static_strides().
-          template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
-        return v.getSExtValue();
+        return static_strides()[idx];
      }]
    >,

@@ -289,7 +280,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*defaultImplementation=*/[{
        assert($_op.isDynamicOffset(idx) && "expected dynamic offset");
        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          static_offsets().template cast<::mlir::ArrayAttr>(),
+          static_offsets(),
          ::mlir::ShapedType::isDynamic,
          idx);
        return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic;
@@ -307,7 +298,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*defaultImplementation=*/[{
        assert($_op.isDynamicSize(idx) && "expected dynamic size");
        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          static_sizes().template cast<::mlir::ArrayAttr>(), ::mlir::ShapedType::isDynamic, idx);
+          static_sizes(), ::mlir::ShapedType::isDynamic, idx);
        return $_op.getOffsetSizeAndStrideStartOperandIndex() +
               offsets().size() + numDynamic;
      }]
@@ -324,7 +315,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
      /*defaultImplementation=*/[{
        assert($_op.isDynamicStride(idx) && "expected dynamic stride");
        auto numDynamic = getNumDynamicEntriesUpToIdx(
-          static_strides().template cast<::mlir::ArrayAttr>(),
+          static_strides(),
          ::mlir::ShapedType::isDynamic,
          idx);
        return $_op.getOffsetSizeAndStrideStartOperandIndex() +
@@ -333,20 +324,20 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
    >,
    InterfaceMethod<
      /*desc=*/[{
-        Helper method to compute the number of dynamic entries of `attr`, up to
+        Helper method to compute the number of dynamic entries of `staticVals`, up to
        `idx` using `isDynamic` to determine whether an entry is dynamic.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumDynamicEntriesUpToIdx",
-      /*args=*/(ins "::mlir::ArrayAttr":$attr,
+      /*args=*/(ins "::llvm::ArrayRef<int64_t>":$staticVals,
                    "::llvm::function_ref<bool(int64_t)>":$isDynamic,
                    "unsigned":$idx),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return std::count_if(
-          attr.getValue().begin(), attr.getValue().begin() + idx,
-          [&](::mlir::Attribute attr) {
-            return isDynamic(attr.cast<::mlir::IntegerAttr>().getInt());
+          staticVals.begin(), staticVals.begin() + idx,
+          [&](int64_t val) {
+            return isDynamic(val);
          });
      }]
    >,

@@ -1705,10 +1705,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
     auto viewMemRefType = subViewOp.getType();
     auto inferredType =
         memref::SubViewOp::inferResultType(
-            subViewOp.getSourceType(),
-            extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
-            extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
-            extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
+            subViewOp.getSourceType(), subViewOp.getStaticOffsets(),
+            subViewOp.getStaticSizes(), subViewOp.getStaticStrides())
             .cast<MemRefType>();
     auto targetElementTy =
         typeConverter->convertType(viewMemRefType.getElementType());

@@ -30,8 +30,8 @@ public:
                   PatternRewriter &rewriter) const final {
     Location loc = sliceOp.getLoc();
     Value input = sliceOp.getInput();
-    SmallVector<int64_t> strides, sizes;
-    auto starts = sliceOp.getStart();
+    SmallVector<int64_t> strides, sizes, starts;
+    starts = extractFromI64ArrayAttr(sliceOp.getStart());
     strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);

     SmallVector<Value> dynSizes;
@@ -44,15 +44,15 @@ public:

       auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
       auto offset = rewriter.create<arith::ConstantOp>(
-          loc,
-          rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+          loc, rewriter.getIndexAttr(starts[index]));
       dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
     }

     auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
         sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
-        ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes),
-        rewriter.getI64ArrayAttr(strides));
+        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+        rewriter.getDenseI64ArrayAttr(sizes),
+        rewriter.getDenseI64ArrayAttr(strides));

     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
     return success();

@@ -40,16 +40,6 @@ static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
   return result;
 }

-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
-  SmallVector<int64_t> result;
-  result.reserve(attr.size());
-  for (APInt value : attr.getAsValueRange<IntegerAttr>())
-    result.push_back(value.getSExtValue());
-  return result;
-}
-
 namespace {
 /// A simple pattern rewriter that implements no special logic.
 class SimpleRewriter : public PatternRewriter {
@@ -1205,7 +1195,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
 DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+  ArrayRef<int64_t> tileSizes = getStaticSizes();

   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1270,7 +1260,7 @@ transform::TileOp::apply(TransformResults &transformResults,
     });
   }

-  tilingOptions.setInterchange(extractI64Array(getInterchange()));
+  tilingOptions.setInterchange(getInterchange());
   SimpleRewriter rewriter(linalgOp.getContext());
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(linalgOp.getOperation()),
@@ -1298,7 +1288,7 @@ transform::TileOp::apply(TransformResults &transformResults,

 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
   ValueRange dynamic = getDynamicSizes();
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+  ArrayRef<int64_t> tileSizes = getStaticSizes();
   SmallVector<OpFoldResult> results;
   results.reserve(tileSizes.size());
   unsigned dynamicPos = 0;
@@ -1313,22 +1303,51 @@ SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
   return results;
 }

+// We want to parse `DenseI64ArrayAttr` using the short form without the
+// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
+ParseResult parseOptionalInterchange(OpAsmParser &parser,
+                                     OperationState &result) {
+  if (succeeded(parser.parseOptionalLBrace())) {
+    if (failed(parser.parseKeyword("interchange")))
+      return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
+    if (failed(parser.parseEqual()))
+      return parser.emitError(parser.getNameLoc()) << "expect `=`";
+    result.addAttribute("interchange",
+                        DenseI64ArrayAttr::parse(parser, Type{}));
+    if (failed(parser.parseRBrace()))
+      return parser.emitError(parser.getNameLoc()) << "expect `}`";
+  }
+  return success();
+}
+
+void printOptionalInterchange(OpAsmPrinter &p,
+                              ArrayRef<int64_t> interchangeVals) {
+  if (!interchangeVals.empty()) {
+    p << " {interchange = [";
+    llvm::interleaveComma(interchangeVals, p,
+                          [&](int64_t integer) { p << integer; });
+    p << "]}";
+  }
+}
+
 ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   OpAsmParser::UnresolvedOperand target;
   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
-  ArrayAttr staticSizes;
+  DenseI64ArrayAttr staticSizes;
   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
   if (parser.parseOperand(target) ||
       parser.resolveOperand(target, pdlOperationType, result.operands) ||
       parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
-      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
-      parser.parseOptionalAttrDict(result.attributes))
+      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
     return ParseResult::failure();

+  // Parse optional interchange.
+  if (failed(parseOptionalInterchange(parser, result)))
+    return ParseResult::failure();
   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   size_t numExpectedLoops =
-      staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
+      staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
   return success();
 }
@@ -1336,7 +1355,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
-  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
+  printOptionalInterchange(p, getInterchange());
 }

 void transform::TileOp::getEffects(
@@ -1379,13 +1398,13 @@ void transform::TileToForeachThreadOp::build(
   // bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
-  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
   build(builder, result,
         /*resultTypes=*/TypeRange{operationType, operationType},
         /*target=*/target,
         /*num_threads=*/ValueRange{},
         /*tile_sizes=*/dynamicTileSizes,
-        /*static_num_threads=*/builder.getI64ArrayAttr({}),
+        /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
        /*static_tile_sizes=*/staticTileSizesAttr,
        /*mapping=*/mapping);
 }
@@ -1414,14 +1433,14 @@ void transform::TileToForeachThreadOp::build(
   // bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
-  auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
+  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
   build(builder, result,
         /*resultTypes=*/TypeRange{operationType, operationType},
         /*target=*/target,
         /*num_threads=*/dynamicNumThreads,
         /*tile_sizes=*/ValueRange{},
         /*static_num_threads=*/staticNumThreadsAttr,
-        /*static_tile_sizes=*/builder.getI64ArrayAttr({}),
+        /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
         /*mapping=*/mapping);
 }

@@ -1547,11 +1566,13 @@ void transform::TileToForeachThreadOp::getEffects(
 }

 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
-  return getMixedValues(getStaticNumThreads(), getNumThreads());
+  Builder b(getContext());
+  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
 }

 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
-  return getMixedValues(getStaticTileSizes(), getTileSizes());
+  Builder b(getContext());
+  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
 }

 LogicalResult TileToForeachThreadOp::verify() {
@@ -1567,7 +1588,7 @@ LogicalResult TileToForeachThreadOp::verify() {
 DiagnosedSilenceableFailure
 transform::TileToScfForOp::apply(TransformResults &transformResults,
                                  TransformState &state) {
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+  ArrayRef<int64_t> tileSizes = getStaticSizes();

   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1632,7 +1653,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
     });
   }

-  tilingOptions.setInterchange(extractI64Array(getInterchange()));
+  tilingOptions.setInterchange(getInterchange());
   SimpleRewriter rewriter(tilingInterfaceOp.getContext());
   FailureOr<scf::SCFTilingResult> tilingResult =
       tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
@@ -1655,7 +1676,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,

 SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() {
   ValueRange dynamic = getDynamicSizes();
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+  ArrayRef<int64_t> tileSizes = getStaticSizes();
   SmallVector<OpFoldResult> results;
   results.reserve(tileSizes.size());
   unsigned dynamicPos = 0;
@@ -1674,18 +1695,20 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand target;
   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
-  ArrayAttr staticSizes;
+  DenseI64ArrayAttr staticSizes;
   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
   if (parser.parseOperand(target) ||
       parser.resolveOperand(target, pdlOperationType, result.operands) ||
       parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
-      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
-      parser.parseOptionalAttrDict(result.attributes))
+      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
    return ParseResult::failure();

+  // Parse optional interchange.
+  if (failed(parseOptionalInterchange(parser, result)))
+    return ParseResult::failure();
   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   size_t numExpectedLoops =
-      staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
+      staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
   return success();
 }
@@ -1693,7 +1716,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
 void TileToScfForOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
-  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
+  printOptionalInterchange(p, getInterchange());
 }

 void transform::TileToScfForOp::getEffects(

@@ -348,7 +348,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
   SmallVector<AffineExpr, 4> outputExprs;
   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
-                          padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
+                          padOp.getStaticLow()[i]);
   }

   SmallVector<AffineMap, 2> transferMaps = {

@@ -1776,8 +1776,9 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamic);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticSizes),
+        b.getDenseI64ArrayAttr(staticStrides));
   result.addAttributes(attrs);
 }

@@ -1823,8 +1824,8 @@ LogicalResult ReinterpretCastOp::verify() {
            << srcType << " and result memref type " << resultType;

   // Match sizes in result memref type and in static_sizes attribute.
-  for (auto &en : llvm::enumerate(llvm::zip(
-           resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
+  for (auto &en :
+       llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
     int64_t resultSize = std::get<0>(en.value());
     int64_t expectedSize = std::get<1>(en.value());
     if (!ShapedType::isDynamic(resultSize) &&
@@ -1844,7 +1845,7 @@ LogicalResult ReinterpretCastOp::verify() {
            << resultType;

   // Match offset in result memref type and in static_offsets attribute.
-  int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
+  int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
       !ShapedType::isDynamic(expectedOffset) &&
       resultOffset != expectedOffset)
@@ -1852,8 +1853,8 @@ LogicalResult ReinterpretCastOp::verify() {
            << resultOffset << " instead of " << expectedOffset;

   // Match strides in result memref type and in static_strides attribute.
-  for (auto &en : llvm::enumerate(llvm::zip(
-           resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
+  for (auto &en :
+       llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
     int64_t resultStride = std::get<0>(en.value());
     int64_t expectedStride = std::get<1>(en.value());
     if (!ShapedType::isDynamic(resultStride) &&
@@ -2665,8 +2666,9 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
                    .cast<MemRefType>();
   }
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticSizes),
+        b.getDenseI64ArrayAttr(staticStrides));
   result.addAttributes(attrs);
 }

@@ -2831,9 +2833,7 @@ LogicalResult SubViewOp::verify() {

   // Verify result type against inferred type.
   auto expectedType = SubViewOp::inferResultType(
-      baseType, extractFromI64ArrayAttr(getStaticOffsets()),
-      extractFromI64ArrayAttr(getStaticSizes()),
-      extractFromI64ArrayAttr(getStaticStrides()));
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());

   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
                                         subViewType, getMixedSizes());

@@ -45,9 +45,8 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
     builder.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
         subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
-        extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
-        extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
-        extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
+        subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+        subviewUse.getStaticStrides());
     Value newSubview = builder.create<memref::SubViewOp>(
         subviewUse->getLoc(), newType.cast<MemRefType>(), val,
         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),

@@ -337,8 +337,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {

     SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
     auto dimMask = computeRankReductionMask(
-        extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
-        extractOperand.getType().getShape());
+        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
     size_t dimIndex = 0;
     for (size_t i = 0, e = sizes.size(); i < e; i++) {
       if (dimMask && dimMask->count(i))
@@ -1713,8 +1712,9 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                      .cast<RankedTensorType>();
   }
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticSizes),
+        b.getDenseI64ArrayAttr(staticStrides));
   result.addAttributes(attrs);
 }

@@ -1949,13 +1949,13 @@ public:
       return failure();

     // Check if there are any dynamic parts, which are not supported.
-    auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
+    auto offsets = op.getStaticOffsets();
     if (llvm::is_contained(offsets, ShapedType::kDynamic))
       return failure();
-    auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
+    auto sizes = op.getStaticSizes();
     if (llvm::is_contained(sizes, ShapedType::kDynamic))
       return failure();
-    auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
+    auto strides = op.getStaticStrides();
     if (llvm::is_contained(strides, ShapedType::kDynamic))
       return failure();

@@ -2124,8 +2124,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamic);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticSizes),
+        b.getDenseI64ArrayAttr(staticStrides));
   result.addAttributes(attrs);
 }

@@ -2153,17 +2154,14 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,

 /// Rank-reducing type verification for both InsertSliceOp and
 /// ParallelInsertSliceOp.
-static SliceVerificationResult
-verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
-                    ArrayAttr staticOffsets, ArrayAttr staticSizes,
-                    ArrayAttr staticStrides,
-                    ShapedType *expectedType = nullptr) {
+static SliceVerificationResult verifyInsertSliceOp(
+    ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
+    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+    ShapedType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type
   // inference.
   RankedTensorType expected = ExtractSliceOp::inferResultType(
-      dstType, extractFromI64ArrayAttr(staticOffsets),
-      extractFromI64ArrayAttr(staticSizes),
-      extractFromI64ArrayAttr(staticStrides));
+      dstType, staticOffsets, staticSizes, staticStrides);
   if (expectedType)
     *expectedType = expected;
   return isRankReducedType(expected, srcType);
@@ -2482,9 +2480,8 @@ ParseResult parseInferType(OpAsmParser &parser,
 LogicalResult PadOp::verify() {
   auto sourceType = getSource().getType().cast<RankedTensorType>();
   auto resultType = getResult().getType().cast<RankedTensorType>();
-  auto expectedType = PadOp::inferResultType(
-      sourceType, extractFromI64ArrayAttr(getStaticLow()),
-      extractFromI64ArrayAttr(getStaticHigh()));
+  auto expectedType =
+      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
       continue;
@@ -2556,8 +2553,9 @@ void PadOp::build(OpBuilder &b, OperationState &result, Value source,
                   ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
-  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
+  build(b, result, resultType, source, low, high,
+        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
+        nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }

@@ -2591,7 +2589,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   }
   assert(resultType.isa<RankedTensorType>());
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
@@ -2658,8 +2656,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {

     auto newResultType = PadOp::inferResultType(
         castOp.getSource().getType().cast<RankedTensorType>(),
-        extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
-        extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
+        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
         padTensorOp.getResultType().getShape());

     if (newResultType == padTensorOp.getResultType()) {
@@ -2940,8 +2937,9 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                              ShapedType::kDynamic);
   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+        b.getDenseI64ArrayAttr(staticSizes),
+        b.getDenseI64ArrayAttr(staticStrides));
   result.addAttributes(attrs);
 }

@@ -3086,12 +3084,12 @@ template <typename OpTy>
 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                 "applies to only pack or unpack operations");
+  Builder builder(op);
   SmallVector<OpFoldResult> mixedInnerTiles;
   unsigned dynamicValIndex = 0;
-  for (Attribute attr : op.getStaticInnerTiles()) {
-    auto tileAttr = attr.cast<IntegerAttr>();
-    if (!ShapedType::isDynamic(tileAttr.getInt()))
-      mixedInnerTiles.push_back(tileAttr);
+  for (int64_t staticTile : op.getStaticInnerTiles()) {
+    if (!ShapedType::isDynamic(staticTile))
+      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
     else
       mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
   }

@@ -137,4 +137,41 @@ SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
         return getValueOrCreateConstantIndexOp(b, loc, value);
       }));
 }

+/// Return a vector of OpFoldResults with the same size a staticValues, but all
+/// elements for which ShapedType::isDynamic is true, will be replaced by
+/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b) {
+  SmallVector<OpFoldResult> res;
+  res.reserve(staticValues.size());
+  unsigned numDynamic = 0;
+  unsigned count = static_cast<unsigned>(staticValues.size());
+  for (unsigned idx = 0; idx < count; ++idx) {
+    int64_t value = staticValues[idx];
+    res.push_back(ShapedType::isDynamic(value)
+                      ? OpFoldResult{dynamicValues[numDynamic++]}
+                      : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+  }
+  return res;
+}
+
+/// Decompose a vector of mixed static or dynamic values into the corresponding
+/// pair of arrays. This is the inverse function of `getMixedValues`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+                     const SmallVectorImpl<OpFoldResult> &mixedValues) {
+  SmallVector<int64_t> staticValues;
+  SmallVector<Value> dynamicValues;
+  for (const auto &it : mixedValues) {
+    if (it.is<Attribute>()) {
+      staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+    } else {
+      staticValues.push_back(ShapedType::kDynamic);
+      dynamicValues.push_back(it.get<Value>());
+    }
+  }
+  return {b.getI64ArrayAttr(staticValues), dynamicValues};
+}
+
 } // namespace mlir

@@ -20,15 +20,15 @@ using namespace mlir;
 LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
                                                    StringRef name,
                                                    unsigned numElements,
-                                                   ArrayAttr attr,
+                                                   ArrayRef<int64_t> staticVals,
                                                    ValueRange values) {
-  /// Check static and dynamic offsets/sizes/strides does not overflow type.
-  if (attr.size() != numElements)
+  // Check static and dynamic offsets/sizes/strides does not overflow type.
+  if (staticVals.size() != numElements)
     return op->emitError("expected ")
            << numElements << " " << name << " values";
   unsigned expectedNumDynamicEntries =
-      llvm::count_if(attr.getValue(), [&](Attribute attr) {
-        return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
+      llvm::count_if(staticVals, [&](int64_t staticVal) {
+        return ShapedType::isDynamic(staticVal);
       });
   if (values.size() != expectedNumDynamicEntries)
     return op->emitError("expected ")
@@ -70,19 +70,19 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
 }

 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                                 OperandRange values, ArrayAttr integers) {
+                                 OperandRange values,
+                                 ArrayRef<int64_t> integers) {
   printer << '[';
   if (integers.empty()) {
     printer << "]";
     return;
   }
   unsigned idx = 0;
-  llvm::interleaveComma(integers, printer, [&](Attribute a) {
-    int64_t val = a.cast<IntegerAttr>().getInt();
-    if (ShapedType::isDynamic(val))
+  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
+    if (ShapedType::isDynamic(integer))
       printer << values[idx++];
     else
-      printer << val;
+      printer << integer;
   });
   printer << ']';
 }
@@ -90,28 +90,28 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers) {
+    DenseI64ArrayAttr &integers) {
   if (failed(parser.parseLSquare()))
     return failure();
   // 0-D.
   if (succeeded(parser.parseOptionalRSquare())) {
-    integers = parser.getBuilder().getArrayAttr({});
+    integers = parser.getBuilder().getDenseI64ArrayAttr({});
     return success();
   }

-  SmallVector<int64_t, 4> attrVals;
+  SmallVector<int64_t, 4> integerVals;
   while (true) {
     OpAsmParser::UnresolvedOperand operand;
     auto res = parser.parseOptionalOperand(operand);
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
-      attrVals.push_back(ShapedType::kDynamic);
+      integerVals.push_back(ShapedType::kDynamic);
     } else {
-      IntegerAttr attr;
-      if (failed(parser.parseAttribute<IntegerAttr>(attr)))
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
        return parser.emitError(parser.getNameLoc())
               << "expected SSA value or integer";
-      attrVals.push_back(attr.getInt());
+      integerVals.push_back(integer);
     }

     if (succeeded(parser.parseOptionalComma()))
@@ -120,7 +120,7 @@ ParseResult mlir::parseDynamicIndexList(
     return failure();
     break;
   }
-  integers = parser.getBuilder().getI64ArrayAttr(attrVals);
+  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
   return success();
 }

@@ -144,34 +144,3 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
       return false;
   return true;
 }
-
-SmallVector<OpFoldResult, 4> mlir::getMixedValues(ArrayAttr staticValues,
-                                                  ValueRange dynamicValues) {
-  SmallVector<OpFoldResult, 4> res;
-  res.reserve(staticValues.size());
-  unsigned numDynamic = 0;
-  unsigned count = static_cast<unsigned>(staticValues.size());
-  for (unsigned idx = 0; idx < count; ++idx) {
-    APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
-    res.push_back(ShapedType::isDynamic(value.getSExtValue())
-                      ? OpFoldResult{dynamicValues[numDynamic++]}
-                      : OpFoldResult{staticValues[idx]});
-  }
-  return res;
-}
-
-std::pair<ArrayAttr, SmallVector<Value>>
-mlir::decomposeMixedValues(Builder &b,
-                           const SmallVectorImpl<OpFoldResult> &mixedValues) {
-  SmallVector<int64_t> staticValues;
-  SmallVector<Value> dynamicValues;
-  for (const auto &it : mixedValues) {
-    if (it.is<Attribute>()) {
-      staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
-    } else {
-      staticValues.push_back(ShapedType::kDynamic);
-      dynamicValues.push_back(it.get<Value>());
-    }
-  }
-  return {b.getI64ArrayAttr(staticValues), dynamicValues};
-}

@@ -49,6 +49,15 @@ def _get_int_array_attr(

   return ArrayAttr.get([_get_int64_attr(v) for v in values])

+def _get_dense_int64_array_attr(
+    values: Sequence[int]) -> DenseI64ArrayAttr:
+  """Creates a dense integer array from a sequence of integers.
+  Expects the thread-local MLIR context to have been set by the context
+  manager.
+  """
+  if values is None:
+    return DenseI64ArrayAttr.get([])
+  return DenseI64ArrayAttr.get(values)

 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
@@ -250,14 +259,11 @@ class TileOp:
     else:
       for size in sizes:
         if isinstance(size, int):
-          static_sizes.append(IntegerAttr.get(i64_type, size))
-        elif isinstance(size, IntegerAttr):
           static_sizes.append(size)
         else:
-          static_sizes.append(
-              IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
+          static_sizes.append(ShapedType.get_dynamic_size())
           dynamic_sizes.append(_get_op_result_or_value(size))
-      sizes_attr = ArrayAttr.get(static_sizes)
+      sizes_attr = DenseI64ArrayAttr.get(static_sizes)

     num_loops = sum(
         v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
@@ -266,14 +272,14 @@ class TileOp:
         _get_op_result_or_value(target),
         dynamic_sizes=dynamic_sizes,
         static_sizes=sizes_attr,
-        interchange=_get_int_array_attr(interchange) if interchange else None,
+        interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
        loc=loc,
        ip=ip)

-  def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
+  def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
     if not attr:
       return []
-    return [IntegerAttr(element).value for element in attr]
+    return [element for element in attr]


 class VectorizeOp:

@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]}
+  transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
 }

 // CHECK-LABEL: func @permute_generic
@@ -191,8 +191,8 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange=[1, 2, 0]}
-  %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange=[1, 0, 2]}
+  %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]}
+  %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]}
   %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
 }

@@ -108,7 +108,6 @@ def testSplit():
   # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
   # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3

-
 @run
 def testTileCompact():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -120,14 +119,11 @@ def testTileCompact():
   # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
   # CHECK: interchange = [0, 1]

-
 @run
 def testTileAttributes():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  attr = ArrayAttr.get(
-      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
-  ichange = ArrayAttr.get(
-      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
+  attr = DenseI64ArrayAttr.get([4, 8])
+  ichange = DenseI64ArrayAttr.get([0, 1])
   with InsertionPoint(sequence.body):
     structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
     transform.YieldOp()
@@ -136,7 +132,6 @@ def testTileAttributes():
   # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
   # CHECK: interchange = [0, 1]

-
 @run
 def testTileZero():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -149,7 +144,6 @@ def testTileZero():
   # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
   # CHECK: interchange = [0, 1, 2, 3]

-
 @run
 def testTileDynamic():
   with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())