mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 10:08:59 +08:00
[mlir][vector] Fix typo, NFC.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D140681
This commit is contained in:
@@ -591,23 +591,23 @@ private:
|
||||
const vector::UnrollVectorOptions options;
|
||||
};
|
||||
|
||||
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
|
||||
UnrollTranposePattern(MLIRContext *context,
|
||||
const vector::UnrollVectorOptions &options,
|
||||
PatternBenefit benefit = 1)
|
||||
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
|
||||
UnrollTransposePattern(MLIRContext *context,
|
||||
const vector::UnrollVectorOptions &options,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<vector::TransposeOp>(context, benefit),
|
||||
options(options) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (tranposeOp.getResultType().getRank() == 0)
|
||||
if (transposeOp.getResultType().getRank() == 0)
|
||||
return failure();
|
||||
auto targetShape = getTargetShape(options, tranposeOp);
|
||||
auto targetShape = getTargetShape(options, transposeOp);
|
||||
if (!targetShape)
|
||||
return failure();
|
||||
auto originalVectorType = tranposeOp.getResultType();
|
||||
auto originalVectorType = transposeOp.getResultType();
|
||||
SmallVector<int64_t> strides(targetShape->size(), 1);
|
||||
Location loc = tranposeOp.getLoc();
|
||||
Location loc = transposeOp.getLoc();
|
||||
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
|
||||
SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
|
||||
int64_t sliceCount = computeMaxLinearIndex(ratio);
|
||||
@@ -615,7 +615,7 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
|
||||
SmallVector<int64_t> permutation;
|
||||
tranposeOp.getTransp(permutation);
|
||||
transposeOp.getTransp(permutation);
|
||||
|
||||
// Stride of the ratios, this gives us the offsets of sliceCount in a basis
|
||||
// of multiples of the targetShape.
|
||||
@@ -631,13 +631,14 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
|
||||
permutedShape[indices.value()] = (*targetShape)[indices.index()];
|
||||
}
|
||||
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||
loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
|
||||
Value tranposedSlice =
|
||||
loc, transposeOp.getVector(), permutedOffsets, permutedShape,
|
||||
strides);
|
||||
Value transposedSlice =
|
||||
rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
|
||||
result = rewriter.create<vector::InsertStridedSliceOp>(
|
||||
loc, tranposedSlice, result, elementOffsets, strides);
|
||||
loc, transposedSlice, result, elementOffsets, strides);
|
||||
}
|
||||
rewriter.replaceOp(tranposeOp, result);
|
||||
rewriter.replaceOp(transposeOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -653,5 +654,5 @@ void mlir::vector::populateVectorUnrollPatterns(
|
||||
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
|
||||
UnrollContractionPattern, UnrollElementwisePattern,
|
||||
UnrollReductionPattern, UnrollMultiReductionPattern,
|
||||
UnrollTranposePattern>(patterns.getContext(), options, benefit);
|
||||
UnrollTransposePattern>(patterns.getContext(), options, benefit);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user