mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir][Vector] Mostly-NFC - Restructure options for lowering to LLVM Matrix Intrinsics
Summary: This revision restructures the calling of vector transforms to make it more flexible to ask for lowering through LLVM matrix intrinsics. This also makes sure we bail out in degenerate cases (i.e. 1) in which LLVM complains about not being able to scalarize. Differential Revision: https://reviews.llvm.org/D76266
This commit is contained in:
@@ -42,13 +42,6 @@ using namespace mlir;
|
||||
using llvm::dbgs;
|
||||
using mlir::functional::zipMap;
|
||||
|
||||
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
||||
|
||||
static llvm::cl::opt<bool> lowerToLLVMMatrixIntrinsics(
|
||||
"vector-lower-matrix-intrinsics",
|
||||
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
||||
llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
/// Given a shape with sizes greater than 0 along all dimensions,
|
||||
/// returns the distance, in number of elements, between a slice in a dimension
|
||||
/// and the next slice in the same dimension.
|
||||
@@ -936,6 +929,11 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
|
||||
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
|
||||
MLIRContext *context)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformsOptions(vectorTransformsOptions) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(ajcbik): implement masks
|
||||
@@ -946,33 +944,41 @@ public:
|
||||
// a new pattern.
|
||||
// TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
|
||||
// intrinsics, use that.
|
||||
if (lowerToLLVMMatrixIntrinsics &&
|
||||
if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
|
||||
isColumnMajorMatmul(op.indexing_maps())) {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedLHSType, op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedRHSType, op.rhs());
|
||||
|
||||
unsigned lhsRows = op.getLhsType().getShape()[0];
|
||||
unsigned lhsColumns = op.getLhsType().getShape()[1];
|
||||
unsigned rhsColumns = op.getRhsType().getShape()[1];
|
||||
Value mul = rewriter.create<vector::MatmulOp>(
|
||||
op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
|
||||
op.acc().getType(), mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return matchSuccess();
|
||||
|
||||
// In cases where matrices are degenerate, scalarization issues occur in
|
||||
// the backend. Avoid all LLVM scalarization issues for now.
|
||||
// For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
|
||||
// https://bugs.llvm.org/show_bug.cgi?id=45229
|
||||
// TODO(ntv, fhahn): Relax once above bugs are fixed.
|
||||
if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedLHSType, op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedRHSType, op.rhs());
|
||||
|
||||
Value mul = rewriter.create<vector::MatmulOp>(
|
||||
op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
|
||||
op.acc().getType(), mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return matchSuccess();
|
||||
}
|
||||
}
|
||||
|
||||
// Find first batch dimension in LHS/RHS, and lower when found.
|
||||
@@ -1255,6 +1261,8 @@ private:
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
vector::VectorTransformsOptions vectorTransformsOptions;
|
||||
};
|
||||
|
||||
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
|
||||
@@ -1342,8 +1350,10 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ContractionOpLowering, ShapeCastOp2DDownCastRewritePattern,
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions parameters) {
|
||||
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
|
||||
context);
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user