[mlir][vector] Add folder for ExtractStridedSliceOp

Add folder for the case where ExtractStridedSliceOp source comes from a chain
of InsertStridedSliceOp. Also add a folder for the trivial case where the
ExtractStridedSliceOp is a no-op.

Differential Revision: https://reviews.llvm.org/D89850
This commit is contained in:
Thomas Raoux
2020-10-23 12:07:25 -07:00
parent bfb04aeb85
commit ea6a60a9a6
3 changed files with 165 additions and 0 deletions

View File

@@ -1016,6 +1016,7 @@ def Vector_ExtractStridedSliceOp :
void getOffsets(SmallVectorImpl<int64_t> &results);
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
}

View File

@@ -1629,6 +1629,81 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
return success();
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
// to use the source o the InsertStrided ops if we can detect that the extracted
// vector is a subset of one of the vector inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// Helper to extract integer out of ArrayAttr.
auto getElement = [](ArrayAttr array, int idx) {
return array[idx].cast<IntegerAttr>().getInt();
};
ArrayAttr extractOffsets = op.offsets();
ArrayAttr extractStrides = op.strides();
ArrayAttr extractSizes = op.sizes();
auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
if (op.getVectorType().getRank() !=
insertOp.getSourceVectorType().getRank())
return failure();
ArrayAttr insertOffsets = insertOp.offsets();
ArrayAttr insertStrides = insertOp.strides();
// If the rank of extract is greater than the rank of insert, we are likely
// extracting a partial chunk of the vector inserted.
if (extractOffsets.size() > insertOffsets.size())
return failure();
bool patialoverlap = false;
bool disjoint = false;
SmallVector<int64_t, 4> offsetDiffs;
for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
return failure();
int64_t start = getElement(insertOffsets, dim);
int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
int64_t offset = getElement(extractOffsets, dim);
int64_t size = getElement(extractSizes, dim);
// Check if the start of the extract offset is in the interval inserted.
if (start <= offset && offset < end) {
// If the extract interval overlaps but is not fully included we may
// have a partial overlap that will prevent any folding.
if (offset + size > end)
patialoverlap = true;
offsetDiffs.push_back(offset - start);
continue;
}
disjoint = true;
break;
}
// The extract element chunk is a subset of the insert element.
if (!disjoint && !patialoverlap) {
op.setOperand(insertOp.source());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
b.getI64ArrayAttr(offsetDiffs));
return success();
}
// If the chunk extracted is disjoint from the chunk inserted, keep looking
// in the insert chain.
if (disjoint)
insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
else {
// The extracted vector partially overlap the inserted vector, we cannot
// fold.
return failure();
}
}
return failure();
}
OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getVectorType() == getResult().getType())
return vector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
return {};
}
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(offsets(), results);
}

View File

@@ -90,6 +90,95 @@ func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
// -----
// CHECK-LABEL: extract_strided_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
%0 = vector.extract_strided_slice %arg
{offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
: vector<4x3xi1> to vector<4x3xi1>
return %0 : vector<4x3xi1>
}
// -----
// CHECK-LABEL: extract_strided_fold_insert
// CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
// CHECK-NEXT: return %[[ARG]] : vector<4x4xf32>
func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
-> (vector<4x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
%1 = vector.extract_strided_slice %0
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
: vector<8x16xf32> to vector<4x4xf32>
return %1 : vector<4x4xf32>
}
// -----
// Case where the vector inserted is a subset of the vector extracted.
// CHECK-LABEL: extract_strided_fold_insert
// CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
// CHECK-SAME: : vector<6x4xf32> to vector<4x4xf32>
// CHECK-NEXT: return %[[EXT]] : vector<4x4xf32>
func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
-> (vector<4x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<6x4xf32> into vector<8x16xf32>
%1 = vector.extract_strided_slice %0
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
: vector<8x16xf32> to vector<4x4xf32>
return %1 : vector<4x4xf32>
}
// -----
// Negative test where the extract is not a subset of the element inserted.
// CHECK-LABEL: extract_strided_fold_negative
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
// CHECK-SAME: : vector<4x4xf32> into vector<8x16xf32>
// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
-> (vector<6x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
%1 = vector.extract_strided_slice %0
{offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
: vector<8x16xf32> to vector<6x4xf32>
return %1 : vector<6x4xf32>
}
// -----
// Case where we need to go through 2 level of insert element.
// CHECK-LABEL: extract_strided_fold_insert
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
: vector<1x4xf32> into vector<2x4xf32>
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
: vector<1x4xf32> into vector<2x4xf32>
%2 = vector.extract_strided_slice %1
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
: vector<2x4xf32> to vector<1x1xf32>
return %2 : vector<1x1xf32>
}
// -----
// CHECK-LABEL: transpose_1D_identity
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {