//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements convenience types for working with super-vectorization // operations, in particular super-vector loads and stores. // //===----------------------------------------------------------------------===// #include "mlir/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Support/LLVM.h" using namespace mlir; //===----------------------------------------------------------------------===// // VectorOpsDialect //===----------------------------------------------------------------------===// VectorOpsDialect::VectorOpsDialect(MLIRContext *context) : Dialect("vector", context) { addOperations(); } //===----------------------------------------------------------------------===// // VectorTransferReadOp //===----------------------------------------------------------------------===// template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { SmallVector seen(permutationMap.getNumInputs(), false); for (auto expr : permutationMap.getResults()) { auto dim = expr.dyn_cast(); auto zero = expr.dyn_cast(); if (zero) { if (zero.getValue() != 0) { return emitOpError( "requires a projected permutation_map (at most one dim or the zero " "constant can appear in each result)"); } continue; } if (!dim) { return emitOpError("requires a projected permutation_map (at most one " "dim or the zero constant can appear in each result)"); } if (seen[dim.getPosition()]) { return emitOpError( "requires a permutation_map that is a permutation (found one dim " "used more than once)"); } seen[dim.getPosition()] = true; } return success(); } void VectorTransferReadOp::build(Builder *builder, OperationState *result, VectorType vectorType, Value *srcMemRef, ArrayRef srcIndices, AffineMap permutationMap, Optional paddingValue) { result->addOperands(srcMemRef); result->addOperands(srcIndices); if (paddingValue) { result->addOperands({*paddingValue}); } result->addAttribute(getPermutationMapAttrName(), builder->getAffineMapAttr(permutationMap)); result->addTypes(vectorType); } auto VectorTransferReadOp::getIndices() -> operand_range { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } Optional VectorTransferReadOp::getPaddingValue() { auto memRefRank = getMemRefType().getRank(); if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { return None; } return Optional(getOperand(Offsets::FirstIndexOffset + memRefRank)); } AffineMap VectorTransferReadOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } void VectorTransferReadOp::print(OpAsmPrinter *p) { *p << getOperationName() << " "; p->printOperand(getMemRef()); *p << "["; p->printOperands(getIndices()); *p << "]"; auto optionalPaddingValue = getPaddingValue(); if (optionalPaddingValue) { *p << ", ("; p->printOperand(*optionalPaddingValue); *p << ")"; } p->printOptionalAttrDict(getAttrs()); *p << " : " << getMemRefType(); *p << ", " << getResultType(); } ParseResult VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; SmallVector paddingInfo; SmallVector types; // Parsing with support for optional paddingValue. if (parser->parseOperand(memrefInfo) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || parser->parseTrailingOperandList(paddingInfo, -1, OpAsmParser::Delimiter::Paren) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) return failure(); // Resolution. if (types.size() != 2) return parser->emitError(parser->getNameLoc(), "expected 2 types"); MemRefType memrefType = types[0].dyn_cast(); if (!memrefType) return parser->emitError(parser->getNameLoc(), "memRef type expected"); VectorType vectorType = types[1].dyn_cast(); if (!vectorType) return parser->emitError(parser->getNameLoc(), "vector type expected"); // Extract optional paddingValue. // At this point, indexInfo may contain the optional paddingValue, pop it out. if (indexInfo.size() != static_cast(memrefType.getRank())) return parser->emitError(parser->getNameLoc(), "expected " + Twine(memrefType.getRank()) + " indices to the memref"); if (paddingInfo.size() > 1) return parser->emitError(parser->getNameLoc(), "expected at most one padding value"); Type paddingType; bool hasOptionalPaddingValue = !paddingInfo.empty(); if (hasOptionalPaddingValue) { paddingType = vectorType.getElementType(); } auto indexType = parser->getBuilder().getIndexType(); return failure( parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperands(indexInfo, indexType, result->operands) || (hasOptionalPaddingValue && parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) || parser->addTypeToList(vectorType, result->types)); } LogicalResult VectorTransferReadOp::verify() { // Consistency of memref type in function type. if (llvm::empty(getOperands())) { return emitOpError( "requires at least a memref operand followed by 'rank' indices"); } if (!getMemRef()->getType().isa()) { return emitOpError("requires a memref as first operand"); } // Consistency of vector type in function type. if (!getResult()->getType().isa()) { return emitOpError("should have a vector result type in function type: " "memref_type<...xelemental_type>, vector_type"); } // Consistency of elemental types in memref and vector. MemRefType memrefType = getMemRefType(); VectorType vectorType = getResultType(); if (memrefType.getElementType() != vectorType.getElementType()) return emitOpError( "requires memref and vector types of the same elemental type"); // Consistency of number of input types. auto optionalPaddingValue = getPaddingValue(); unsigned expectedNumOperands = Offsets::FirstIndexOffset + memrefType.getRank() + (optionalPaddingValue ? 1 : 0); // Checks on the actual operands and their types. if (getNumOperands() != expectedNumOperands) { return emitOpError("expects ") << expectedNumOperands << " operands (of which " << memrefType.getRank() << " indices)"; } // Consistency of padding value with vector type. if (optionalPaddingValue) { auto paddingValue = *optionalPaddingValue; auto elementalType = paddingValue->getType(); if (!VectorType::isValidElementType(elementalType)) { return emitOpError("requires valid padding vector elemental type"); } if (elementalType != vectorType.getElementType()) { return emitOpError( "requires formal padding and vector of the same elemental type"); } } // Consistency of indices types. unsigned numIndices = 0; for (auto *idx : getIndices()) { if (!idx->getType().isIndex()) { return emitOpError( "index to vector.transfer_read must have 'index' type"); } ++numIndices; } if (numIndices != memrefType.getRank()) { return emitOpError("requires at least a memref operand followed by ") << memrefType.getRank() << " indices"; } // Consistency of AffineMap attribute. if (!getAttrOfType(getPermutationMapAttrName())) { return emitOpError("requires an AffineMapAttr named 'permutation_map'"); } auto permutationMap = getPermutationMap(); if (permutationMap.getNumSymbols() != 0) { return emitOpError("requires a permutation_map without symbols"); } if (permutationMap.getNumInputs() != memrefType.getRank()) { return emitOpError("requires a permutation_map with input dims of the " "same rank as the memref type"); } if (permutationMap.getNumResults() != vectorType.getRank()) { return emitOpError("requires a permutation_map with result dims of the " "same rank as the vector type (") << permutationMap.getNumResults() << " vs " << vectorType.getRank(); } return verifyPermutationMap(permutationMap, [this](Twine t) { return emitOpError(t); }); } //===----------------------------------------------------------------------===// // VectorTransferWriteOp //===----------------------------------------------------------------------===// void VectorTransferWriteOp::build(Builder *builder, OperationState *result, Value *srcVector, Value *dstMemRef, ArrayRef dstIndices, AffineMap permutationMap) { result->addOperands({srcVector, dstMemRef}); result->addOperands(dstIndices); result->addAttribute(getPermutationMapAttrName(), builder->getAffineMapAttr(permutationMap)); } auto VectorTransferWriteOp::getIndices() -> operand_range { auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; auto end = begin + getMemRefType().getRank(); return {begin, end}; } AffineMap VectorTransferWriteOp::getPermutationMap() { return getAttrOfType(getPermutationMapAttrName()).getValue(); } void VectorTransferWriteOp::print(OpAsmPrinter *p) { *p << getOperationName(); *p << " " << *getVector(); *p << ", " << *getMemRef(); *p << "["; p->printOperands(getIndices()); *p << "]"; p->printOptionalAttrDict(getAttrs()); *p << " : "; p->printType(getVectorType()); *p << ", "; p->printType(getMemRefType()); } ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; SmallVector types; auto indexType = parser->getBuilder().getIndexType(); if (parser->parseOperand(storeValueInfo) || parser->parseComma() || parser->parseOperand(memrefInfo) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) return failure(); if (types.size() != 2) return parser->emitError(parser->getNameLoc(), "expected 2 types"); VectorType vectorType = types[Offsets::VectorOffset].dyn_cast(); if (!vectorType) return parser->emitError(parser->getNameLoc(), "vector type expected"); MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast(); if (!memrefType) return parser->emitError(parser->getNameLoc(), "memRef type expected"); return failure( parser->resolveOperands(storeValueInfo, vectorType, result->operands) || parser->resolveOperands(memrefInfo, memrefType, result->operands) || parser->resolveOperands(indexInfo, indexType, result->operands)); } LogicalResult VectorTransferWriteOp::verify() { // Consistency of memref type in function type. if (llvm::empty(getOperands())) { return emitOpError( "requires at least a memref operand followed by 'rank' indices"); } if (!getMemRef()->getType().isa()) { return emitOpError("requires a memref first operand"); } // Consistency of vector type in function type. if (!getVector()->getType().isa()) { return emitOpError("should have a vector input type in function type: " "(vector_type, memref_type [, elemental_type]) -> ()"); } // Consistency of elemental types in memref and vector. MemRefType memrefType = getMemRefType(); VectorType vectorType = getVectorType(); if (memrefType.getElementType() != vectorType.getElementType()) return emitOpError( "requires memref and vector types of the same elemental type"); // Consistency of number of input types. unsigned expectedNumOperands = Offsets::FirstIndexOffset + memrefType.getRank(); // Checks on the actual operands and their types. if (getNumOperands() != expectedNumOperands) { return emitOpError() << "expects " << expectedNumOperands << " operands (of which " << memrefType.getRank() << " indices)"; } // Consistency of indices types. unsigned numIndices = 0; for (auto *idx : getIndices()) { if (!idx->getType().isIndex()) { return emitOpError( "index to vector.transfer_write must have 'index' type"); } numIndices++; } if (numIndices != memrefType.getRank()) { return emitOpError("requires at least a memref operand followed by ") << memrefType.getRank() << " indices"; } // Consistency of AffineMap attribute. if (!getAttrOfType(getPermutationMapAttrName())) { return emitOpError("requires an AffineMapAttr named 'permutation_map'"); } auto permutationMap = getPermutationMap(); if (permutationMap.getNumSymbols() != 0) { return emitOpError("requires a permutation_map without symbols"); } if (permutationMap.getNumInputs() != memrefType.getRank()) { return emitOpError("requires a permutation_map with input dims of the " "same rank as the memref type"); } if (permutationMap.getNumResults() != vectorType.getRank()) { return emitOpError("requires a permutation_map with result dims of the " "same rank as the vector type (") << permutationMap.getNumResults() << " vs " << vectorType.getRank(); } return verifyPermutationMap(permutationMap, [this](Twine t) { return emitOpError(t); }); } //===----------------------------------------------------------------------===// // VectorTypeCastOp //===----------------------------------------------------------------------===// void VectorTypeCastOp::build(Builder *builder, OperationState *result, Value *srcVector, Type dstType) { result->addOperands(srcVector); result->addTypes(dstType); } ParseResult VectorTypeCastOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operand; Type srcType, dstType; return failure(parser->parseOperand(operand) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(srcType) || parser->parseComma() || parser->parseType(dstType) || parser->addTypeToList(dstType, result->types) || parser->resolveOperand(operand, srcType, result->operands)); } void VectorTypeCastOp::print(OpAsmPrinter *p) { *p << getOperationName() << ' ' << *getOperand() << " : " << getOperand()->getType() << ", " << getType(); } LogicalResult VectorTypeCastOp::verify() { auto dstMemrefType = getType().dyn_cast(); if (!dstMemrefType) return emitOpError("expects target type to be a memref type"); auto dstVectorType = dstMemrefType.getElementType().dyn_cast(); if (!dstVectorType) return emitOpError( "expects vector as an element of the target memref type"); if (!dstMemrefType.hasStaticShape()) return emitOpError("does not support dynamic shapes"); if (!getOperand()->getType().isa()) return emitOpError("expects source type to be a memref type"); return success(); }