mirror of
https://github.com/intel/llvm.git
synced 2026-02-03 10:39:35 +08:00
[mlir][vector] Pattern to clean up vector.extract during distribution
This prevents blocking propagation when converting between scalar and vector<1> Differential Revision: https://reviews.llvm.org/D129782
This commit is contained in:
@@ -719,6 +719,33 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to move out vector.extract of single element vector. Those don't
|
||||
/// need to be distributed and can just be propagated outside of the region.
|
||||
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
OpOperand *operand = getWarpResult(
|
||||
warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
|
||||
if (!operand)
|
||||
return failure();
|
||||
unsigned int operandNumber = operand->getOperandNumber();
|
||||
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
|
||||
if (extractOp.getVectorType().getNumElements() != 1)
|
||||
return failure();
|
||||
Location loc = extractOp.getLoc();
|
||||
SmallVector<size_t> newRetIndices;
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
|
||||
newRetIndices);
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
Value newExtract = rewriter.create<vector::ExtractOp>(
|
||||
loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
|
||||
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
|
||||
/// the scf.ForOp is the last operation in the region so that it doesn't change
|
||||
/// the order of execution. This creates a new scf.for region after the
|
||||
@@ -915,8 +942,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
|
||||
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
|
||||
WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
|
||||
WarpOpConstant>(patterns.getContext());
|
||||
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
|
||||
WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateDistributeReduction(
|
||||
|
||||
Reference in New Issue
Block a user