//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using llvm::dbgs;
using mlir::functional::zipMap;

/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// distance, in number of elements, between a slice in a dimension and the
/// next slice in the same dimension.
///   e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
static SmallVector<int64_t, 8> computeStrides(ArrayRef<int64_t> shape) {
  if (shape.empty())
    return {};
  SmallVector<int64_t, 8> tmp;
  tmp.reserve(shape.size());
  int64_t running = 1;
  for (auto size : llvm::reverse(shape)) {
    assert(size > 0 && "size must be positive");
    tmp.push_back(running);
    running *= size;
  }
  return SmallVector<int64_t, 8>(tmp.rbegin(), tmp.rend());
}

static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
  if (basis.empty())
    return 0;
  int64_t res = 1;
  for (auto b : basis)
    res *= b;
  return res;
}

/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(offsets.size() == basis.size());
  int64_t linearIndex = 0;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex += offsets[idx] * basis[idx];
  return linearIndex;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return builder.createOperation(res);
}

// Populates 'resultElements[indexMap[i]]' with elements from
// 'inputElements[i]' for each index 'i' in 'inputElements' that has a valid
// mapping in 'indexMap'.
static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
                              ArrayRef<int64_t> inputElements,
                              SmallVectorImpl<int64_t> &resultElements) {
  assert(indexMap.size() == resultElements.size());
  assert(inputElements.size() >= resultElements.size());
  for (unsigned i = 0, e = inputElements.size(); i < e; ++i) {
    auto it = indexMap.find(i);
    if (it != indexMap.end())
      resultElements[it->second] = inputElements[i];
  }
}
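// For example, with shape [3, 4, 5] the linearization helpers above compose
// as follows, where 'delinearize' (declared in VectorUtils.h) acts as the
// inverse of 'linearize' for in-bounds offsets:
//
//   computeStrides({3, 4, 5})         == {20, 5, 1}
//   computeMaxLinearIndex({3, 4, 5})  == 60   // total number of elements
//   linearize({1, 2, 3}, {20, 5, 1})  == 33   // 1*20 + 2*5 + 3*1
//   delinearize({20, 5, 1}, 33)       == {1, 2, 3}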
// Returns a tuple type with vector element types for each resulting slice
// of 'vectorType' unrolled by 'sizes' and 'strides'.
// TODO(andydavis) Move this to a utility function and share it with
// Extract/InsertSlicesOp verification.
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
                                                   ArrayRef<int64_t> sizes,
                                                   ArrayRef<int64_t> strides,
                                                   PatternRewriter &builder) {
  assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
  assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
  assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());

  // Compute shape ratio of 'shape' and 'sizes'.
  auto shape = vectorType.getShape();
  auto maybeDimSliceCounts = shapeRatio(shape, sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;

  // Compute strides w.r.t number of slices in each dimension.
  auto sliceStrides = computeStrides(sliceDimCounts);
  int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
  SmallVector<Type, 4> vectorTypes(sliceCount);
  for (unsigned i = 0; i < sliceCount; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
    // Create Vector type and add to 'vectorTypes[i]'.
    vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
  }
  return TupleType::get(vectorTypes, builder.getContext());
}

// UnrolledVectorState aggregates per-operand/result vector state required for
// unrolling.
struct UnrolledVectorState {
  SmallVector<int64_t, 4> unrolledShape;
  SmallVector<int64_t, 4> unrollFactors;
  SmallVector<int64_t, 8> basis;
  int64_t numInstances;
  Value slicesTuple;
};

// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
static void initUnrolledVectorState(VectorType vectorType, Value initValue,
                                    const DenseMap<int64_t, int64_t> &indexMap,
                                    ArrayRef<int64_t> targetShape,
                                    UnrolledVectorState &state,
                                    PatternRewriter &builder) {
  // Compute unrolled shape of 'vectorType'.
  state.unrolledShape.resize(vectorType.getRank());
  getMappedElements(indexMap, targetShape, state.unrolledShape);
  // Compute unroll factors for unrolled shape.
  auto maybeUnrollFactors =
      shapeRatio(vectorType.getShape(), state.unrolledShape);
  assert(maybeUnrollFactors.hasValue());
  state.unrollFactors = *maybeUnrollFactors;
  // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
  state.basis = computeStrides(state.unrollFactors);
  state.numInstances = computeMaxLinearIndex(state.unrollFactors);
  state.slicesTuple = nullptr;
  if (initValue != nullptr) {
    // Create ExtractSlicesOp.
    SmallVector<int64_t, 4> sizes(state.unrolledShape);
    SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
    auto tupleType =
        generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
    state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
        initValue.getLoc(), tupleType, initValue, sizes, strides);
  }
}

// Computes and returns the linear index of the unrolled vector at
// 'vectorOffsets' within the vector represented by 'state'.
static int64_t
getUnrolledVectorLinearIndex(UnrolledVectorState &state,
                             ArrayRef<int64_t> vectorOffsets,
                             DenseMap<int64_t, int64_t> &indexMap) {
  // Compute vector offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, vectorOffsets, sliceOffsets);
  // Compute and return linear index of 'sliceOffsets' w.r.t 'state.basis'.
  return linearize(sliceOffsets, state.basis);
}
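// For a concrete picture of the state computed above (illustrative values):
// unrolling a vector<4x6xf32> operand with target shape [2, 3] gives
//
//   unrollFactors = shapeRatio([4, 6], [2, 3]) = [2, 2]
//   basis         = computeStrides([2, 2])     = [2, 1]
//   numInstances  = 4
//
// and 'slicesTuple' holds the result of a vector.extract_slices producing
//   tuple<vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>, vector<2x3xf32>>
// Slice (i, j) of the unrolled iteration space is then found at linear index
// i * 2 + j in that tuple.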
// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
static Value getOrCreateUnrolledVectorSlice(
    Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
    ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
    Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
  // Compute slice offsets.
  SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
  getMappedElements(indexMap, offsets, sliceOffsets);
  // TODO(b/144845578) Support non-1 strides.
  SmallVector<int64_t, 4> sliceStrides(state.unrolledShape.size(), 1);
  // Compute linear index of 'sliceOffsets' w.r.t 'state.basis'.
  int64_t sliceLinearIndex =
      getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
  assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
  auto valueSlice = cache[sliceLinearIndex];
  if (valueSlice == nullptr) {
    // Return tuple element at 'sliceLinearIndex'.
    auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
    auto initValueType = initValue.getType().cast<VectorType>();
    auto vectorType =
        VectorType::get(state.unrolledShape, initValueType.getElementType());
    // Initialize 'cache' with slice from 'initValue'.
    valueSlice = builder.create<vector::TupleGetOp>(
        loc, vectorType, state.slicesTuple, tupleIndex);
    // Store value back to 'cache'.
    cache[sliceLinearIndex] = valueSlice;
  }
  return valueSlice;
}

// VectorState aggregates per-operand/result vector state required for
// creating slices of vector operands, and clones of the operation being
// unrolled.
struct VectorState {
  // The type of this vector.
  VectorType type;
  // Map from iteration space index to vector dimension index.
  DenseMap<int64_t, int64_t> indexMap;
  // Index of this value in the operation's operand list (-1 if not an
  // operand).
  int64_t operandIndex = -1;
  // Accumulator iterator flag.
  bool isAcc = false;
};

//
// unrollSingleResultStructuredOp
//
// Returns a value representing the result of structured operation 'op'
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// A list of VectorState objects must be specified in 'vectors', where
// each VectorState in the list represents a vector operand or vector result
// (if the operation does not have an accumulator operand).
// The VectorState at index 'resultIndex' in the list must be the state
// associated with the operation's single result (i.e. either its accumulator
// operand or vector result value).
//
// Example:
//
//  // Before unrolling
//
//   operand0                operand1                operand2
//       \                      |                      /
//        -------------------- opA --------------------
//
//  // After unrolling by 2
//
//   operand0                operand1                operand2
//   /      \                /      \                /      \
// slice00  slice01       slice10  slice11        slice20  slice21
//   \         |            |          |            /          |
//    -------------------- opA0 --------------------           |
//             |            |          |                       |
//              \           |          |                      /
//               -------------------- opA1 -------------------
//                           |          |
//                            \        /
//                           insertslice
//                                |
//
// TODO(andydavis) Add the following canonicalization/simplification patterns:
// *) Add pattern which matches InsertStridedSlice -> StridedSlice and forwards
//    the InsertStridedSlice operand to the StridedSlice.
// *) Add pattern which matches SourceOp -> StridedSlice -> UserOp which checks
//    if there are duplicate identical StridedSlice ops from SourceOp, and
//    rewrites itself to use the first duplicate. This transformation should
//    cause users of identical StridedSlice ops to reuse the same StridedSlice
//    operation, and leave the duplicate StridedSlice ops with no users
//    (removable with DCE).
// TODO(andydavis) Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
static Value unrollSingleResultStructuredOp(Operation *op,
                                            ArrayRef<int64_t> iterationBounds,
                                            std::vector<VectorState> &vectors,
                                            unsigned resultIndex,
                                            ArrayRef<int64_t> targetShape,
                                            PatternRewriter &builder) {
  auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
  if (!shapedType || !shapedType.hasStaticShape())
    assert(false && "Expected a statically shaped result type");

  // Compute unroll factors for 'iterationBounds' based on 'targetShape'.
  auto maybeUnrollFactors = shapeRatio(iterationBounds, targetShape);
  if (!maybeUnrollFactors.hasValue())
    assert(false && "Failed to compute unroll factors for target shape");
  auto unrollFactors = *maybeUnrollFactors;

  // Compute unrolled vector state for each vector in 'vectors'.
  unsigned numVectors = vectors.size();
  SmallVector<UnrolledVectorState, 4> unrolledVectorState(numVectors);
  for (unsigned i = 0; i < numVectors; ++i) {
    int64_t operandIndex = vectors[i].operandIndex;
    auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
    initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
                            targetShape, unrolledVectorState[i], builder);
  }
  // Compute number of total unrolled instances.
  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
  auto sliceStrides = computeStrides(unrollFactors);

  auto &resultValueState = unrolledVectorState[resultIndex];
  auto unrolledResultType = VectorType::get(resultValueState.unrolledShape,
                                            shapedType.getElementType());

  // Initialize caches for intermediate vector results.
  std::vector<SmallVector<Value, 4>> caches(numVectors);
  for (unsigned i = 0; i < numVectors; ++i)
    caches[i].resize(unrolledVectorState[i].numInstances);

  // Unroll 'numUnrolledInstances' of 'op', storing results in 'caches'.
  for (unsigned i = 0; i < numUnrolledInstances; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
    // Get cached slice (or create slice) for each operand at 'offsets'.
    SmallVector<Value, 3> operands;
    operands.resize(op->getNumOperands());
    for (unsigned j = 0; j < numVectors; ++j) {
      int64_t operandIndex = vectors[j].operandIndex;
      if (operandIndex < 0)
        continue; // Output
      auto operand = op->getOperand(operandIndex);
      operands[operandIndex] = getOrCreateUnrolledVectorSlice(
          op->getLoc(), unrolledVectorState[j], vectorOffsets, elementOffsets,
          vectors[j].indexMap, operand, caches[j], builder);
    }
    // Create op on sliced vector arguments.
    auto resultVector =
        cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
                                    unrolledResultType)
            ->getResult(0);

    // Compute linear result index.
    int64_t linearIndex = getUnrolledVectorLinearIndex(
        resultValueState, vectorOffsets, vectors[resultIndex].indexMap);
    // Update result cache at 'linearIndex'.
    caches[resultIndex][linearIndex] = resultVector;
  }

  // Create TupleOp of unrolled result vectors.
  SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
  SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
  for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
    vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
    vectorTupleValues[i] = caches[resultIndex][i];
  }
  TupleType tupleType = builder.getTupleType(vectorTupleTypes);
  Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
                                                  vectorTupleValues);

  // Create InsertSlicesOp(Tuple(result_vectors)).
  auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
  SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
  SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);

  Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
      op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
      builder.getI64ArrayAttr(strides));
  return insertSlicesOp;
}

static void getVectorContractionOpUnrollState(
    vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
    SmallVectorImpl<int64_t> &iterationBounds,
    std::vector<VectorState> &vectors, unsigned &resultIndex) {
  // Get contraction op iteration bounds.
  contractionOp.getIterationBounds(iterationBounds);
  assert(iterationBounds.size() == targetShape.size());
  // Get map from iteration space index to lhs/rhs/result shape index.
  std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
  contractionOp.getIterationIndexMap(iterationIndexMapList);
  unsigned numIterators = iterationIndexMapList.size();
  vectors.resize(numIterators);
  unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
  for (unsigned i = 0; i < numIterators; ++i) {
    vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
    vectors[i].indexMap = iterationIndexMapList[i];
    vectors[i].operandIndex = i;
    vectors[i].isAcc = i == accOperandIndex;
  }

  if (llvm::size(contractionOp.masks()) == 2) {
    // Add vectors for lhs/rhs vector mask arguments. Masks have the
    // same vector shape as the lhs/rhs args, so copy their index maps.
    vectors.push_back({contractionOp.getLHSVectorMaskType(),
                       vectors[0].indexMap, accOperandIndex + 1, false});
    vectors.push_back({contractionOp.getRHSVectorMaskType(),
                       vectors[1].indexMap, accOperandIndex + 2, false});
  }
  // Unroll 'op' 'iterationBounds' to 'targetShape'.
  // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
  // 'vectors' instead of 'resultIndex'.
  resultIndex = accOperandIndex;
}

static void getVectorElementwiseOpUnrollState(
    Operation *op, ArrayRef<int64_t> targetShape,
    SmallVectorImpl<int64_t> &iterationBounds,
    std::vector<VectorState> &vectors, unsigned &resultIndex) {
  // Verify that the operation and its operands all have the same vector shape.
  auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
  assert(resultType && "Expected op with vector result type");
  auto resultShape = resultType.getShape();
  // Verify that all operands have the same vector type as result.
  assert(llvm::all_of(op->getOperandTypes(),
                      [=](Type type) { return type == resultType; }));
  // Populate 'iterationBounds' with 'resultShape' for elementwise operations.
  iterationBounds.assign(resultShape.begin(), resultShape.end());

  // Create trivial elementwise identity index map based on 'resultShape'.
  DenseMap<int64_t, int64_t> indexMap;
  indexMap.reserve(resultShape.size());
  for (unsigned i = 0; i < resultShape.size(); ++i)
    indexMap[i] = i;

  // Create a VectorState for each operand and for the single result.
  unsigned numVectors = op->getNumOperands() + op->getNumResults();
  vectors.resize(numVectors);
  for (unsigned i = 0; i < op->getNumOperands(); ++i)
    vectors[i] = {resultType, indexMap, i, false};
  vectors[numVectors - 1] = {resultType, indexMap, -1, false};
  resultIndex = numVectors - 1;
}
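// As a rough illustration of the elementwise path (shapes chosen arbitrarily):
// unrolling
//
//   %0 = addf %a, %b : vector<4x4xf32>
//
// to a target shape of 2x2 rewrites it, modulo the exact assembly syntax, into
// four addf ops on vector<2x2xf32> slices obtained via vector.extract_slices /
// vector.tuple_get, whose results are recombined with vector.tuple +
// vector.insert_slices into the original vector<4x4xf32>.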
// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
    PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
  assert(op->getNumResults() == 1 && "Expected single result operation");

  // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
  SmallVector<int64_t, 6> iterationBounds;
  std::vector<VectorState> vectors;
  unsigned resultIndex;

  if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
    // Populate state for vector ContractionOp.
    getVectorContractionOpUnrollState(contractionOp, targetShape,
                                      iterationBounds, vectors, resultIndex);
  } else {
    // Populate state for vector elementwise op.
    getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds,
                                      vectors, resultIndex);
  }

  // Unroll 'op' with 'iterationBounds' to 'targetShape'.
  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
      op, iterationBounds, vectors, resultIndex, targetShape, builder)};
}

/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
/// calls 'fn' with the linear index and indices for each slice.
static void generateTransferOpSlices(
    Type memrefElementType, VectorType vectorType, TupleType tupleType,
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
    PatternRewriter &rewriter,
    function_ref<void(unsigned, ArrayRef<Value>)> fn) {
  // Compute strides w.r.t. slice counts in each dimension.
  auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
  assert(maybeDimSliceCounts.hasValue());
  auto sliceDimCounts = *maybeDimSliceCounts;
  auto sliceStrides = computeStrides(sliceDimCounts);

  int64_t numSlices = tupleType.size();
  unsigned numSliceIndices = indices.size();
  // Compute 'indexOffset' at which to update 'indices', which is equal
  // to the memref rank (indices.size) minus the effective 'vectorRank'.
  // The effective 'vectorRank' is equal to the rank of the vector type
  // minus the rank of the memref vector element type (if it has one).
  //
  // For example:
  //
  //   Given a memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
  //   transfer_read/write ops which read/write vectors of type
  //   'vector<2x1x2x4xf32>', the memref rank is 3, the effective
  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
  //
  unsigned vectorRank = vectorType.getRank();
  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
    assert(vectorRank >= memrefVectorElementType.getRank());
    vectorRank -= memrefVectorElementType.getRank();
  }
  unsigned indexOffset = numSliceIndices - vectorRank;

  auto *ctx = rewriter.getContext();
  for (unsigned i = 0; i < numSlices; ++i) {
    auto vectorOffsets = delinearize(sliceStrides, i);
    auto elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
    // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
    SmallVector<Value, 4> sliceIndices(numSliceIndices);
    for (unsigned j = 0; j < numSliceIndices; ++j) {
      if (j < indexOffset) {
        sliceIndices[j] = indices[j];
      } else {
        auto expr = getAffineDimExpr(0, ctx) +
                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
        sliceIndices[j] = rewriter.create<AffineApplyOp>(
            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
      }
    }
    // Call 'fn' to generate slice 'i' at 'sliceIndices'.
    fn(i, sliceIndices);
  }
}

/// Returns true if 'map' is a suffix of an identity affine map, false
/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
static bool isIdentitySuffix(AffineMap map) {
  if (map.getNumDims() < map.getNumResults())
    return false;
  ArrayRef<AffineExpr> results = map.getResults();
  Optional<int> lastPos;
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
    auto expr = results[i].dyn_cast<AffineDimExpr>();
    if (!expr)
      return false;
    int currPos = static_cast<int>(expr.getPosition());
    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
      return false;
    lastPos = currPos;
  }
  return true;
}

namespace {

// Splits vector TransferReadOp into smaller TransferReadOps based on the
// slicing scheme of its unique ExtractSlicesOp user.
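// Schematically (assembly syntax abbreviated, affine.apply ops for zero
// offsets omitted):
//
//   %v = vector.transfer_read %A[%i, %j] ... : memref<?x?xf32>, vector<4x4xf32>
//   %t = vector.extract_slices %v, [2, 4], [1, 1]
//          : vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
//
// is rewritten so that each tuple element is read directly:
//
//   %v0 = vector.transfer_read %A[%i, %j] ...  : ..., vector<2x4xf32>
//   %i2 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i)
//   %v1 = vector.transfer_read %A[%i2, %j] ... : ..., vector<2x4xf32>
//   %t  = vector.tuple %v0, %v1 : vector<2x4xf32>, vector<2x4xf32>
//   %v  = vector.insert_slices %t, [2, 4], [1, 1]
//          : tuple<vector<2x4xf32>, vector<2x4xf32>> into vector<4x4xf32>
//
// The original extract_slices then folds away via TupleGetFolderOp below.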
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
                                     PatternRewriter &rewriter) const override {
    // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferReadOp.permutation_map()))
      return matchFailure();
    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
    Value xferReadResult = xferReadOp.getResult();
    if (!xferReadResult.hasOneUse())
      return matchFailure();
    auto extractSlicesOp =
        dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
    if (!extractSlicesOp)
      return matchFailure();

    // Get 'sizes' and 'strides' parameters from the ExtractSlicesOp user.
    auto sourceVectorType = extractSlicesOp.getSourceVectorType();
    auto resultTupleType = extractSlicesOp.getResultTupleType();
    SmallVector<int64_t, 4> sizes;
    extractSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    extractSlicesOp.getStrides(strides);
    assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));

    Location loc = xferReadOp.getLoc();
    auto memrefElementType =
        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
    int64_t numSlices = resultTupleType.size();
    SmallVector<Value, 4> vectorTupleValues(numSlices);
    SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
                                  xferReadOp.indices().end());
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
      // Get VectorType for slice 'index'.
      auto sliceVectorType = resultTupleType.getType(index);
      // Create split TransferReadOp for slice 'index'.
      vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
          loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
          xferReadOp.permutation_map(), xferReadOp.padding());
    };
    generateTransferOpSlices(memrefElementType, sourceVectorType,
                             resultTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Create tuple of split transfer read operations.
    Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
                                                     vectorTupleValues);
    // Replace 'xferReadOp' with result 'insertSlicesResult'.
    rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
        xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
        extractSlicesOp.strides());
    return matchSuccess();
  }
};

// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
                                     PatternRewriter &rewriter) const override {
    // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
      return matchFailure();
    // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
    auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
    if (!insertSlicesOp)
      return matchFailure();

    // Get TupleOp operand of 'insertSlicesOp'.
    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
        insertSlicesOp.vectors().getDefiningOp());
    if (!tupleOp)
      return matchFailure();

    // Get 'sizes' and 'strides' parameters from the InsertSlicesOp user.
    auto sourceTupleType = insertSlicesOp.getSourceTupleType();
    auto resultVectorType = insertSlicesOp.getResultVectorType();
    SmallVector<int64_t, 4> sizes;
    insertSlicesOp.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    insertSlicesOp.getStrides(strides);

    Location loc = xferWriteOp.getLoc();
    auto memrefElementType =
        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
    SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                  xferWriteOp.indices().end());
    auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
      // Create split TransferWriteOp for source vector 'tupleOp.operand[index]'.
      rewriter.create<vector::TransferWriteOp>(
          loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
          xferWriteOp.permutation_map());
    };
    generateTransferOpSlices(memrefElementType, resultVectorType,
                             sourceTupleType, sizes, strides, indices, rewriter,
                             createSlice);

    // Erase old 'xferWriteOp'.
    rewriter.eraseOp(xferWriteOp);
    return matchSuccess();
  }
};

/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
/// on vector types.
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                     PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has tuple source/result type.
    auto sourceTupleType =
        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
    auto resultTupleType =
        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
    if (!sourceTupleType || !resultTupleType)
      return matchFailure();
    assert(sourceTupleType.size() == resultTupleType.size());

    // Create single-vector ShapeCastOp for each source tuple element.
    Location loc = shapeCastOp.getLoc();
    SmallVector<Value, 8> resultElements;
    resultElements.reserve(resultTupleType.size());
    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
      auto sourceElement = rewriter.create<vector::TupleGetOp>(
          loc, sourceTupleType.getType(i), shapeCastOp.source(),
          rewriter.getI64IntegerAttr(i));
      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
          loc, resultTupleType.getType(i), sourceElement));
    }

    // Replace 'shapeCastOp' with tuple of 'resultElements'.
    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
                                                 resultElements);
    return matchSuccess();
  }
};

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                     PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return matchFailure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return matchFailure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType =
        sourceShapeCastOp.result().getType().cast<VectorType>();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return matchFailure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return matchSuccess();
  }
};

// Pattern rewrite which forwards tuple elements to their users.
// User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
//   -> User(Producer)
struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
  using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                     PatternRewriter &rewriter) const override {
    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
        tupleGetOp.vectors().getDefiningOp());
    if (!extractSlicesOp)
      return matchFailure();

    // Return if 'extractSlicesOp.vector' arg was not defined by
    // InsertSlicesOp.
    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
        extractSlicesOp.vector().getDefiningOp());
    if (!insertSlicesOp)
      return matchFailure();

    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
        insertSlicesOp.vectors().getDefiningOp());
    if (!tupleOp)
      return matchFailure();

    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
    rewriter.replaceOp(tupleGetOp, tupleValue);
    return matchSuccess();
  }
};

/// Progressive lowering of ExtractSlicesOp to tuple of StridedSliceOp.
/// One:
///   %x = vector.extract_slices %0
/// is replaced by:
///   %a = vector.strided_slice %0
///   %b = vector.strided_slice %0
///   ..
///   %x = vector.tuple %a, %b, ..
class ExtractSlicesOpLowering
    : public OpRewritePattern<vector::ExtractSlicesOp> {
public:
  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getSourceVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // For each element in the tuple, generate the proper strided slice.
    TupleType tupleType = op.getResultTupleType();
    int64_t tupleSize = tupleType.size();
    SmallVector<Value, 4> tupleValues(tupleSize);
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      auto sliceSizes = computeSliceSizes(shape, sizes, elementOffsets);
      // Insert in tuple.
      tupleValues[i] = rewriter.create<vector::StridedSliceOp>(
          loc, op.vector(), elementOffsets, sliceSizes, strides);
    }

    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
    return matchSuccess();
  }
};

/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
/// One:
///   %x = vector.insert_slices %0
/// is replaced by:
///   %r0 = vector.splat 0
///   %t1 = vector.tuple_get %0, 0
///   %r1 = vector.insert_strided_slice %r0, %t1
///   %t2 = vector.tuple_get %0, 1
///   %r2 = vector.insert_strided_slice %r1, %t2
///   ..
///   %x  = ..
class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
                                     PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType vectorType = op.getResultVectorType();
    auto shape = vectorType.getShape();

    SmallVector<int64_t, 4> sizes;
    op.getSizes(sizes);
    SmallVector<int64_t, 4> strides;
    op.getStrides(strides); // all-ones at the moment

    // Prepare result.
    auto elemType = vectorType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);

    // For each element in the tuple, extract the proper strided slice.
    TupleType tupleType = op.getSourceTupleType();
    int64_t tupleSize = tupleType.size();
    auto sliceStrides = computeStrides(shape, sizes);
    for (int64_t i = 0; i < tupleSize; ++i) {
      auto vectorOffsets = delinearize(sliceStrides, i);
      auto elementOffsets =
          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
      // Extract from tuple into the result.
      auto index = rewriter.getI64IntegerAttr(i);
      auto tupleGet = rewriter.create<vector::TupleGetOp>(
          loc, tupleType.getType(i), op.getOperand(), index);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tupleGet, result, elementOffsets, strides);
    }

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a fma/reduction op.
///
/// TODO(ajcbik): break down into transpose/reshape/cast ops
///               when they become available to avoid code dup
/// TODO(ajcbik): investigate lowering order impact on performance
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(vector::ContractionOp op,
                                     PatternRewriter &rewriter) const override {
    // TODO(ajcbik): implement masks
    if (llvm::size(op.masks()) != 0)
      return matchFailure();

    // Find first batch dimension in LHS/RHS, and lower when found.
    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
    if (!batchDimMap.empty()) {
      int64_t lhsIndex = batchDimMap[0].first;
      int64_t rhsIndex = batchDimMap[0].second;
      rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
      return matchSuccess();
    }

    // Collect contracting dimensions.
    std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
        op.getContractingDimMap();
    DenseSet<int64_t> lhsContractingDimSet;
    DenseSet<int64_t> rhsContractingDimSet;
    for (auto &dimPair : contractingDimMap) {
      lhsContractingDimSet.insert(dimPair.first);
      rhsContractingDimSet.insert(dimPair.second);
    }

    // Find first free dimension in LHS, and lower when found.
    VectorType lhsType = op.getLhsType();
    for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
         ++lhsIndex) {
      if (lhsContractingDimSet.count(lhsIndex) == 0) {
        rewriter.replaceOp(
            op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
        return matchSuccess();
      }
    }

    // Find first free dimension in RHS, and lower when found.
    VectorType rhsType = op.getRhsType();
    for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
         ++rhsIndex) {
      if (rhsContractingDimSet.count(rhsIndex) == 0) {
        rewriter.replaceOp(
            op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
        return matchSuccess();
      }
    }

    // Lower the first remaining reduction dimension.
    if (!contractingDimMap.empty()) {
      rewriter.replaceOp(op, lowerReduction(op, rewriter));
      return matchSuccess();
    }

    return matchFailure();
  }
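  // As a rough sketch of the recursion (types and indexing maps abbreviated):
  // a matrix-vector contraction
  //
  //   %x = vector.contract %A, %b, %acc
  //          : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
  //
  // has one free LHS dimension, so lowerParallel() peels it off and emits one
  // 1-D contraction per row. lowerReduction() then turns each of those into a
  // vector.fma feeding a vector reduction, and the scalar results are
  // re-inserted into the vector<2xf32> result.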
private:
  // Lower one parallel dimension.
  // TODO(ajcbik): consider reusing existing contract unrolling
  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                      int64_t rhsIndex, PatternRewriter &rewriter) const {
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    VectorType resType = op.getResultType().cast<VectorType>();
    // Find the iterator type index and result index.
    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
    int64_t iterIndex = -1;
    int64_t dimSize = -1;
    if (lhsIndex >= 0) {
      iterIndex =
          iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
      assert((rhsIndex < 0 || iterIndex == iMap[1]
                                               .getResult(rhsIndex)
                                               .cast<AffineDimExpr>()
                                               .getPosition()) &&
             "parallel index should be free in LHS or batch in LHS/RHS");
      dimSize = lhsType.getDimSize(lhsIndex);
    } else {
      assert(rhsIndex >= 0 && "missing parallel index");
      iterIndex =
          iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
      dimSize = rhsType.getDimSize(rhsIndex);
    }
    assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
    Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
    assert(lookup.hasValue() && "parallel index not listed in reduction");
    int64_t resIndex = lookup.getValue();
    // Construct new iterator types and affine map array attribute.
    SmallVector<AffineMap, 4> lowIndexingMaps;
    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
    auto lowIter =
        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
    // Unroll into a series of lower dimensional vector.contract ops.
    Location loc = op.getLoc();
    Value result = zeroVector(loc, resType, rewriter);
    for (int64_t d = 0; d < dimSize; ++d) {
      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
      auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
      Value lowContract = rewriter.create<vector::ContractionOp>(
          loc, lhs, rhs, acc, lowAffine, lowIter);
      result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
                            rewriter);
    }
    return result;
  }

  // Lower one reduction dimension.
  Value lowerReduction(vector::ContractionOp op,
                       PatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    Type resType = op.getResultType();
    assert(!resType.isa<VectorType>());
    // Use iterator index 0.
    int64_t iterIndex = 0;
    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
    Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
    Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
    assert(lookupLhs.hasValue() && "missing LHS parallel index");
    assert(lookupRhs.hasValue() && "missing RHS parallel index");
    int64_t lhsIndex = lookupLhs.getValue();
    int64_t rhsIndex = lookupRhs.getValue();
    int64_t dimSize = lhsType.getDimSize(lhsIndex);
    assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
    // Base case.
    if (lhsType.getRank() == 1) {
      assert(rhsType.getRank() == 1 && "corrupt contraction");
      Value zero = zeroVector(loc, lhsType, rewriter);
      Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
      StringAttr kind = rewriter.getStringAttr("add");
      return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
                                                    op.acc());
    }
    // Construct new iterator types and affine map array attribute.
    SmallVector<AffineMap, 4> lowIndexingMaps;
    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
    auto lowIter =
        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
    // Unroll into a series of lower dimensional vector.contract ops.
    // By feeding the initial accumulator into the first contraction,
    // and the result of each contraction into the next, eventually
    // the sum of all reductions is computed.
    Value result = op.acc();
    for (int64_t d = 0; d < dimSize; ++d) {
      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
      result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                      lowAffine, lowIter);
    }
    return result;
  }

  // Helper method to construct a zero vector.
  static Value zeroVector(Location loc, VectorType vType,
                          PatternRewriter &rewriter) {
    Type eltType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, eltType,
                                             rewriter.getZeroAttr(eltType));
    return rewriter.create<SplatOp>(loc, vType, zero);
  }

  // Helper to find an index in an affine map.
  static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
      if (idx == index)
        return i;
    }
    return None;
  }

  // Helper to construct iterator types with one index removed.
  static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                              int64_t index) {
    SmallVector<Attribute, 4> results;
    for (auto it : llvm::enumerate(iteratorTypes)) {
      int64_t idx = it.index();
      if (idx == index)
        continue;
      results.push_back(it.value());
    }
    return results;
  }

  // Helper to construct an affine map with one index removed.
  static AffineMap adjustMap(AffineMap map, int64_t index,
                             PatternRewriter &rewriter) {
    SmallVector<AffineExpr, 4> results;
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
      if (idx == index)
        continue;
      // Re-insert remaining indices, but renamed when occurring
      // after the removed index.
      auto targetExpr =
          getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext());
      results.push_back(targetExpr);
    }
    // Since (...) -> () cannot be represented properly,
    // we resort to an empty map when this situation happens.
    return results.empty() ? AffineMap::get(rewriter.getContext())
                           : AffineMap::get(map.getNumDims() - 1, 0, results);
  }

  // Helper to drop a dimension from a vector type.
  static Type adjustType(VectorType tp, int64_t index) {
    int64_t rank = tp.getRank();
    Type eltType = tp.getElementType();
    if (rank == 1) {
      assert(index == 0 && "index for scalar result out of bounds");
      return eltType;
    }
    SmallVector<int64_t, 4> adjustedShape;
    for (int64_t i = 0; i < rank; ++i) {
      // Omit dimension at the given index.
      if (i == index)
        continue;
      // Otherwise, add dimension back.
      adjustedShape.push_back(tp.getDimSize(i));
    }
    return VectorType::get(adjustedShape, eltType);
  }

  // Helper method to possibly drop a dimension in a load.
  // TODO(ajcbik): use a reshaping vector load (and share lowering code)
  static Value reshapeLoad(Location loc, Value val, VectorType type,
                           int64_t index, int64_t pos,
                           PatternRewriter &rewriter) {
    if (index == -1)
      return val;
    Type lowType = adjustType(type, 0);
    // At extraction dimension?
    if (index == 0) {
      auto posAttr = rewriter.getI64ArrayAttr(pos);
      return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
    }
    // Unroll leading dimensions.
    VectorType vType = lowType.cast<VectorType>();
    VectorType resType = adjustType(type, index).cast<VectorType>();
    Value result = zeroVector(loc, resType, rewriter);
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
      auto posAttr = rewriter.getI64ArrayAttr(d);
      Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
      Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
      result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
                                                 posAttr);
    }
    return result;
  }

  // Helper method to possibly drop a dimension in a store.
  // TODO(ajcbik): use a reshaping vector store (and share lowering code)
  static Value reshapeStore(Location loc, Value val, Value result,
                            VectorType type, int64_t index, int64_t pos,
                            PatternRewriter &rewriter) {
    // Unmodified?
    if (index == -1)
      return val;
    // At insertion dimension?
    if (index == 0) {
      auto posAttr = rewriter.getI64ArrayAttr(pos);
      return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
    }
    // Unroll leading dimensions.
    Type lowType = adjustType(type, 0);
    VectorType vType = lowType.cast<VectorType>();
    Type insType = adjustType(vType, 0);
    for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
      auto posAttr = rewriter.getI64ArrayAttr(d);
      Value ext =
          rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
      Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
      Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
      result = rewriter.create<vector::InsertOp>(loc, type, sto, result,
                                                 posAttr);
    }
    return result;
  }
};

} // namespace

// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
                  SplitTransferWriteOp, TupleGetFolderOp>(context);
}

void mlir::vector::populateVectorSlicesLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
}

void mlir::vector::populateVectorContractLoweringPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<ContractionOpLowering>(context);
}
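
// Illustrative sketch (not a pass defined in this file): a typical driver
// collects these pattern sets and applies them greedily on a function, e.g.
//
//   OwningRewritePatternList patterns;
//   vector::populateVectorToVectorTransformationPatterns(patterns, context);
//   vector::populateVectorSlicesLoweringPatterns(patterns, context);
//   vector::populateVectorContractLoweringPatterns(patterns, context);
//   applyPatternsGreedily(function, patterns);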