Standardize all VectorOps class names to be prefixed by Vector - NFC

This improves consistency and will concretely avoid collisions between VectorExtractElementOp and ExtractElementOp when they are included in the same transforms / rewrites.

PiperOrigin-RevId: 281101588
This commit is contained in:
Nicolas Vasilache
2019-11-18 10:38:35 -08:00
committed by A. Unique TensorFlower
parent f0f3b71d67
commit 9732bb533c
3 changed files with 51 additions and 36 deletions

View File

@@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def ExtractElementOp :
def VectorExtractElementOp :
Vector_Op<"extractelement", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
@@ -66,14 +66,17 @@ def ExtractElementOp :
%2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
```
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value *source, ArrayRef<int32_t>">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
VectorType getVectorType() {
return vector()->getType().cast<VectorType>();
}
}];
}
def OuterProductOp :
def VectorOuterProductOp :
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
Results<(outs AnyVector)> {

View File

@@ -15,9 +15,9 @@
// limitations under the License.
// =============================================================================
#include "mlir/Conversion/VectorConversions/VectorConversions.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorConversions/VectorConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/Attributes.h"
@@ -49,19 +49,19 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
.getPointerTo();
}
class ExtractElementOpConversion : public LLVMOpLowering {
class VectorExtractElementOpConversion : public LLVMOpLowering {
public:
explicit ExtractElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
typeConverter) {}
explicit VectorExtractElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::VectorExtractElementOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractOp = cast<vector::ExtractElementOp>(op);
auto adaptor = vector::VectorExtractElementOpOperandAdaptor(operands);
auto extractOp = cast<vector::VectorExtractElementOp>(op);
auto vectorType = extractOp.vector()->getType().cast<VectorType>();
auto resultType = extractOp.getResult()->getType();
auto llvmResultType = lowering.convertType(resultType);
@@ -103,25 +103,25 @@ public:
}
};
class OuterProductOpConversion : public LLVMOpLowering {
class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
explicit OuterProductOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
typeConverter) {}
explicit VectorOuterProductOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::VectorOuterProductOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto adaptor = vector::VectorOuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
cast<vector::VectorOuterProductOp>(op).getResult()->getType());
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
Value *a = adaptor.lhs(), *b = adaptor.rhs();
Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
@@ -246,8 +246,8 @@ public:
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
VectorTypeCastOpConversion>(
patterns.insert<VectorExtractElementOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}

View File

@@ -44,17 +44,34 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
// VectorExtractElementOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ExtractElementOp op) {
static Type inferExtractOpResultType(VectorType vectorType,
ArrayAttr position) {
if (static_cast<int64_t>(position.size()) == vectorType.getRank())
return vectorType.getElementType();
return VectorType::get(vectorType.getShape().drop_front(position.size()),
vectorType.getElementType());
}
void VectorExtractElementOp::build(Builder *builder, OperationState &result,
Value *source, ArrayRef<int32_t> position) {
result.addOperands(source);
auto positionAttr = builder->getI32ArrayAttr(position);
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
positionAttr));
result.addAttribute(getPositionAttrName(), positionAttr);
}
static void print(OpAsmPrinter &p, VectorExtractElementOp op) {
p << op.getOperationName() << " " << *op.vector() << op.position();
p.printOptionalAttrDict(op.getAttrs(), {"position"});
p << " : " << op.vector()->getType();
}
static ParseResult parseExtractElementOp(OpAsmParser &parser,
OperationState &result) {
static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc attributeLoc, typeLoc;
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType vector;
@@ -77,19 +94,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
attributeLoc,
"expected position attribute of rank smaller than vector");
Type resType =
(static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
? vectorType.getElementType()
: VectorType::get(
vectorType.getShape().drop_front(positionAttr.size()),
vectorType.getElementType());
Type resType = inferExtractOpResultType(vectorType, positionAttr);
result.attributes = attrs;
return failure(parser.resolveOperand(vector, type, result.operands) ||
parser.addTypeToList(resType, result.types));
}
static LogicalResult verify(ExtractElementOp op) {
static LogicalResult verify(VectorExtractElementOp op) {
auto positionAttr = op.position().getValue();
if (positionAttr.empty())
return op.emitOpError("expected non-empty position attribute");
@@ -107,19 +118,20 @@ static LogicalResult verify(ExtractElementOp op) {
}
return success();
}
//===----------------------------------------------------------------------===//
// OuterProductOp
// VectorOuterProductOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, OuterProductOp op) {
static void print(OpAsmPrinter &p, VectorOuterProductOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
if (llvm::size(op.acc()) > 0)
p << ", " << **op.acc().begin();
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
static ParseResult parseOuterProductOp(OpAsmParser &parser,
OperationState &result) {
static ParseResult parseVectorOuterProductOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
Type tLHS, tRHS;
if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
@@ -142,7 +154,7 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
parser.addTypeToList(resType, result.types));
}
static LogicalResult verify(OuterProductOp op) {
static LogicalResult verify(VectorOuterProductOp op) {
VectorType vLHS = op.getOperandVectorTypeLHS(),
vRHS = op.getOperandVectorTypeRHS(),
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();