mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][vector] Add folder for ExtractStridedSliceOp
Add folder for the case where ExtractStridedSliceOp source comes from a chain of InsertStridedSliceOp. Also add a folder for the trivial case where the ExtractStridedSliceOp is a no-op. Differential Revision: https://reviews.llvm.org/D89850
This commit is contained in:
@@ -1016,6 +1016,7 @@ def Vector_ExtractStridedSliceOp :
|
||||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
|
||||
}
|
||||
|
||||
|
||||
@@ -1629,6 +1629,81 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
|
||||
// to use the source o the InsertStrided ops if we can detect that the extracted
|
||||
// vector is a subset of one of the vector inserted.
|
||||
static LogicalResult
|
||||
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
|
||||
// Helper to extract integer out of ArrayAttr.
|
||||
auto getElement = [](ArrayAttr array, int idx) {
|
||||
return array[idx].cast<IntegerAttr>().getInt();
|
||||
};
|
||||
ArrayAttr extractOffsets = op.offsets();
|
||||
ArrayAttr extractStrides = op.strides();
|
||||
ArrayAttr extractSizes = op.sizes();
|
||||
auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
|
||||
while (insertOp) {
|
||||
if (op.getVectorType().getRank() !=
|
||||
insertOp.getSourceVectorType().getRank())
|
||||
return failure();
|
||||
ArrayAttr insertOffsets = insertOp.offsets();
|
||||
ArrayAttr insertStrides = insertOp.strides();
|
||||
// If the rank of extract is greater than the rank of insert, we are likely
|
||||
// extracting a partial chunk of the vector inserted.
|
||||
if (extractOffsets.size() > insertOffsets.size())
|
||||
return failure();
|
||||
bool patialoverlap = false;
|
||||
bool disjoint = false;
|
||||
SmallVector<int64_t, 4> offsetDiffs;
|
||||
for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
|
||||
if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
|
||||
return failure();
|
||||
int64_t start = getElement(insertOffsets, dim);
|
||||
int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
|
||||
int64_t offset = getElement(extractOffsets, dim);
|
||||
int64_t size = getElement(extractSizes, dim);
|
||||
// Check if the start of the extract offset is in the interval inserted.
|
||||
if (start <= offset && offset < end) {
|
||||
// If the extract interval overlaps but is not fully included we may
|
||||
// have a partial overlap that will prevent any folding.
|
||||
if (offset + size > end)
|
||||
patialoverlap = true;
|
||||
offsetDiffs.push_back(offset - start);
|
||||
continue;
|
||||
}
|
||||
disjoint = true;
|
||||
break;
|
||||
}
|
||||
// The extract element chunk is a subset of the insert element.
|
||||
if (!disjoint && !patialoverlap) {
|
||||
op.setOperand(insertOp.source());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(op.getContext());
|
||||
op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
|
||||
b.getI64ArrayAttr(offsetDiffs));
|
||||
return success();
|
||||
}
|
||||
// If the chunk extracted is disjoint from the chunk inserted, keep looking
|
||||
// in the insert chain.
|
||||
if (disjoint)
|
||||
insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
|
||||
else {
|
||||
// The extracted vector partially overlap the inserted vector, we cannot
|
||||
// fold.
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getVectorType() == getResult().getType())
|
||||
return vector();
|
||||
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
|
||||
return getResult();
|
||||
return {};
|
||||
}
|
||||
|
||||
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
|
||||
populateFromInt64AttrArray(offsets(), results);
|
||||
}
|
||||
|
||||
@@ -90,6 +90,95 @@ func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_fold
|
||||
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
|
||||
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
|
||||
func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
|
||||
%0 = vector.extract_strided_slice %arg
|
||||
{offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
|
||||
: vector<4x3xi1> to vector<4x3xi1>
|
||||
return %0 : vector<4x3xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_fold_insert
|
||||
// CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
|
||||
// CHECK-NEXT: return %[[ARG]] : vector<4x4xf32>
|
||||
func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
|
||||
-> (vector<4x4xf32>) {
|
||||
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
|
||||
: vector<4x4xf32> into vector<8x16xf32>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
|
||||
: vector<8x16xf32> to vector<4x4xf32>
|
||||
return %1 : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Case where the vector inserted is a subset of the vector extracted.
|
||||
// CHECK-LABEL: extract_strided_fold_insert
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
|
||||
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
|
||||
// CHECK-SAME: {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
|
||||
// CHECK-SAME: : vector<6x4xf32> to vector<4x4xf32>
|
||||
// CHECK-NEXT: return %[[EXT]] : vector<4x4xf32>
|
||||
func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
|
||||
-> (vector<4x4xf32>) {
|
||||
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
|
||||
: vector<6x4xf32> into vector<8x16xf32>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
|
||||
: vector<8x16xf32> to vector<4x4xf32>
|
||||
return %1 : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test where the extract is not a subset of the element inserted.
|
||||
// CHECK-LABEL: extract_strided_fold_negative
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
|
||||
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
|
||||
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
|
||||
// CHECK-SAME: : vector<4x4xf32> into vector<8x16xf32>
|
||||
// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
|
||||
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
|
||||
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
|
||||
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
|
||||
func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
|
||||
-> (vector<6x4xf32>) {
|
||||
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
|
||||
: vector<4x4xf32> into vector<8x16xf32>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
|
||||
: vector<8x16xf32> to vector<6x4xf32>
|
||||
return %1 : vector<6x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Case where we need to go through 2 level of insert element.
|
||||
// CHECK-LABEL: extract_strided_fold_insert
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
|
||||
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
|
||||
// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
|
||||
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
|
||||
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
|
||||
func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
|
||||
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
|
||||
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
|
||||
: vector<1x4xf32> into vector<2x4xf32>
|
||||
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
|
||||
: vector<1x4xf32> into vector<2x4xf32>
|
||||
%2 = vector.extract_strided_slice %1
|
||||
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
|
||||
: vector<2x4xf32> to vector<1x1xf32>
|
||||
return %2 : vector<1x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: transpose_1D_identity
|
||||
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
|
||||
func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
|
||||
|
||||
Reference in New Issue
Block a user