[mlir] NFC - VectorTransforms use OpBuilder where relevant

Summary: This will allow using unrolling outside of only rewrite patterns.

Differential Revision: https://reviews.llvm.org/D80083
This commit is contained in:
Nicolas Vasilache
2020-05-17 10:15:58 -04:00
parent 6f02633a4f
commit 1d6eb09d22
2 changed files with 30 additions and 22 deletions

View File

@@ -65,7 +65,7 @@ namespace vector {
// This will be extended in the future to support more advanced use cases than
// simple pointwise ops.
SmallVector<Value, 1>
unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
} // namespace vector

View File

@@ -68,8 +68,8 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
// Clones `op` into a new operations that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
Location loc, Operation *op,
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
Operation *op,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
@@ -98,7 +98,7 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
PatternRewriter &builder) {
OpBuilder &builder) {
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
@@ -140,7 +140,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
UnrolledVectorState &state,
PatternRewriter &builder) {
OpBuilder &builder) {
// Compute unrolled shape of 'vectorType'.
state.unrolledShape.resize(vectorType.getRank());
getMappedElements(indexMap, targetShape, state.unrolledShape);
@@ -183,7 +183,7 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
static Value getOrCreateUnrolledVectorSlice(
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
// Compute slice offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
getMappedElements(indexMap, offsets, sliceOffsets);
@@ -275,7 +275,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
std::vector<VectorState> &vectors,
unsigned resultIndex,
ArrayRef<int64_t> targetShape,
PatternRewriter &builder) {
OpBuilder &builder) {
auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Expected a statically shaped result type");
@@ -426,7 +426,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
@@ -451,12 +451,10 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
static void
generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
TupleType tupleType, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides, ArrayRef<Value> indices,
PatternRewriter &rewriter,
function_ref<void(unsigned, ArrayRef<Value>)> fn) {
static void generateTransferOpSlices(
Type memrefElementType, VectorType vectorType, TupleType tupleType,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
assert(maybeDimSliceCounts.hasValue());
@@ -484,7 +482,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
}
unsigned indexOffset = numSliceIndices - vectorRank;
auto *ctx = rewriter.getContext();
auto *ctx = builder.getContext();
for (unsigned i = 0; i < numSlices; ++i) {
auto vectorOffsets = delinearize(sliceStrides, i);
auto elementOffsets =
@@ -498,7 +496,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
auto expr = getAffineDimExpr(0, ctx) +
getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
sliceIndices[j] = rewriter.create<AffineApplyOp>(
sliceIndices[j] = builder.create<AffineApplyOp>(
indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
}
}
@@ -1683,8 +1681,13 @@ public:
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
SplitTransferWriteOp, TupleGetFolderOp>(context);
// clang-format off
patterns.insert<ShapeCastOpDecomposer,
ShapeCastOpFolder,
SplitTransferReadOp,
SplitTransferWriteOp,
TupleGetFolderOp>(context);
// clang-format on
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
@@ -1695,9 +1698,14 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
void mlir::vector::populateVectorContractLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context,
VectorTransformsOptions parameters) {
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
TransposeOpLowering, OuterProductOpLowering,
ConstantMaskOpLowering, CreateMaskOpLowering>(context);
// clang-format off
patterns.insert<BroadcastOpLowering,
CreateMaskOpLowering,
ConstantMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern,
TransposeOpLowering>(context);
// clang-format on
patterns.insert<ContractionOpLowering>(parameters, context);
}