mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 16:06:39 +08:00
[mlir] NFC - VectorTransforms use OpBuilder where relevant
Summary: This will allow using unrolling outside of only rewrite patterns. Differential Revision: https://reviews.llvm.org/D80083
This commit is contained in:
@@ -65,7 +65,7 @@ namespace vector {
|
||||
// This will be extended in the future to support more advanced use cases than
|
||||
// simple pointwise ops.
|
||||
SmallVector<Value, 1>
|
||||
unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
|
||||
unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<int64_t> targetShape);
|
||||
|
||||
} // namespace vector
|
||||
|
||||
@@ -68,8 +68,8 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
||||
|
||||
// Clones `op` into a new operations that takes `operands` and returns
|
||||
// `resultTypes`.
|
||||
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
||||
Location loc, Operation *op,
|
||||
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
|
||||
Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> resultTypes) {
|
||||
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
|
||||
@@ -98,7 +98,7 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
|
||||
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
|
||||
ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides,
|
||||
PatternRewriter &builder) {
|
||||
OpBuilder &builder) {
|
||||
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
|
||||
assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
|
||||
assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
|
||||
@@ -140,7 +140,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
|
||||
const DenseMap<int64_t, int64_t> &indexMap,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
UnrolledVectorState &state,
|
||||
PatternRewriter &builder) {
|
||||
OpBuilder &builder) {
|
||||
// Compute unrolled shape of 'vectorType'.
|
||||
state.unrolledShape.resize(vectorType.getRank());
|
||||
getMappedElements(indexMap, targetShape, state.unrolledShape);
|
||||
@@ -183,7 +183,7 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
|
||||
static Value getOrCreateUnrolledVectorSlice(
|
||||
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
|
||||
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
|
||||
Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
|
||||
Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
|
||||
// Compute slice offsets.
|
||||
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
|
||||
getMappedElements(indexMap, offsets, sliceOffsets);
|
||||
@@ -275,7 +275,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
||||
std::vector<VectorState> &vectors,
|
||||
unsigned resultIndex,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
PatternRewriter &builder) {
|
||||
OpBuilder &builder) {
|
||||
auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
assert(false && "Expected a statically shaped result type");
|
||||
@@ -426,7 +426,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
||||
|
||||
// Entry point for unrolling declarative pattern rewrites.
|
||||
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
|
||||
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
|
||||
OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
|
||||
assert(op->getNumResults() == 1 && "Expected single result operation");
|
||||
|
||||
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
|
||||
@@ -451,12 +451,10 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
|
||||
|
||||
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
|
||||
/// calls 'fn' with linear index and indices for each slice.
|
||||
static void
|
||||
generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
|
||||
TupleType tupleType, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides, ArrayRef<Value> indices,
|
||||
PatternRewriter &rewriter,
|
||||
function_ref<void(unsigned, ArrayRef<Value>)> fn) {
|
||||
static void generateTransferOpSlices(
|
||||
Type memrefElementType, VectorType vectorType, TupleType tupleType,
|
||||
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
|
||||
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
|
||||
// Compute strides w.r.t. to slice counts in each dimension.
|
||||
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
|
||||
assert(maybeDimSliceCounts.hasValue());
|
||||
@@ -484,7 +482,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
|
||||
}
|
||||
unsigned indexOffset = numSliceIndices - vectorRank;
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto *ctx = builder.getContext();
|
||||
for (unsigned i = 0; i < numSlices; ++i) {
|
||||
auto vectorOffsets = delinearize(sliceStrides, i);
|
||||
auto elementOffsets =
|
||||
@@ -498,7 +496,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
|
||||
auto expr = getAffineDimExpr(0, ctx) +
|
||||
getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
|
||||
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||
sliceIndices[j] = rewriter.create<AffineApplyOp>(
|
||||
sliceIndices[j] = builder.create<AffineApplyOp>(
|
||||
indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
|
||||
}
|
||||
}
|
||||
@@ -1683,8 +1681,13 @@ public:
|
||||
// TODO(andydavis) Add this as DRR pattern.
|
||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
|
||||
SplitTransferWriteOp, TupleGetFolderOp>(context);
|
||||
// clang-format off
|
||||
patterns.insert<ShapeCastOpDecomposer,
|
||||
ShapeCastOpFolder,
|
||||
SplitTransferReadOp,
|
||||
SplitTransferWriteOp,
|
||||
TupleGetFolderOp>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorSlicesLoweringPatterns(
|
||||
@@ -1695,9 +1698,14 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
|
||||
void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions parameters) {
|
||||
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
|
||||
TransposeOpLowering, OuterProductOpLowering,
|
||||
ConstantMaskOpLowering, CreateMaskOpLowering>(context);
|
||||
// clang-format off
|
||||
patterns.insert<BroadcastOpLowering,
|
||||
CreateMaskOpLowering,
|
||||
ConstantMaskOpLowering,
|
||||
OuterProductOpLowering,
|
||||
ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern,
|
||||
TransposeOpLowering>(context);
|
||||
// clang-format on
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user