[mlir][vector] Fix typo, NFC.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140681
This commit is contained in:
jacquesguan
2022-12-27 12:45:06 +08:00
parent 5a3a527f8a
commit 490c77e46d

View File

@@ -591,23 +591,23 @@ private:
const vector::UnrollVectorOptions options;
};
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTransposePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
options(options) {}
LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
if (tranposeOp.getResultType().getRank() == 0)
if (transposeOp.getResultType().getRank() == 0)
return failure();
auto targetShape = getTargetShape(options, tranposeOp);
auto targetShape = getTargetShape(options, transposeOp);
if (!targetShape)
return failure();
auto originalVectorType = tranposeOp.getResultType();
auto originalVectorType = transposeOp.getResultType();
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = tranposeOp.getLoc();
Location loc = transposeOp.getLoc();
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
int64_t sliceCount = computeMaxLinearIndex(ratio);
@@ -615,7 +615,7 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
Value result = rewriter.create<arith::ConstantOp>(
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
SmallVector<int64_t> permutation;
tranposeOp.getTransp(permutation);
transposeOp.getTransp(permutation);
// Stride of the ratios, this gives us the offsets of sliceCount in a basis
// of multiples of the targetShape.
@@ -631,13 +631,14 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
Value tranposedSlice =
loc, transposeOp.getVector(), permutedOffsets, permutedShape,
strides);
Value transposedSlice =
rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, tranposedSlice, result, elementOffsets, strides);
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(tranposeOp, result);
rewriter.replaceOp(transposeOp, result);
return success();
}
@@ -653,5 +654,5 @@ void mlir::vector::populateVectorUnrollPatterns(
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTranposePattern>(patterns.getContext(), options, benefit);
UnrollTransposePattern>(patterns.getContext(), options, benefit);
}