[mlir][vector] Pattern to clean up vector.extract during distribution

This prevents blocking propagation when converting between scalar and
vector<1>

Differential Revision: https://reviews.llvm.org/D129782
This commit is contained in:
Thomas Raoux
2022-07-14 15:34:22 +00:00
parent 0e718443c7
commit f48ce52c4c
2 changed files with 47 additions and 2 deletions

View File

@@ -719,6 +719,33 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(
warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
if (extractOp.getVectorType().getNumElements() != 1)
return failure();
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newExtract = rewriter.create<vector::ExtractOp>(
loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
return success();
}
};
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
@@ -915,8 +942,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
WarpOpConstant>(patterns.getContext());
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
}
void mlir::vector::populateDistributeReduction(