mirror of
https://github.com/intel/llvm.git
synced 2026-01-28 09:14:23 +08:00
Standardize all VectorOps class names to be prefixed by Vector - NFC
This improves consistency and will concretely avoid collisions between VectorExtractElementOp and ExtractElementOp when they are included in the same transforms / rewrites. PiperOrigin-RevId: 281101588
This commit is contained in:
committed by
A. Unique TensorFlower
parent
f0f3b71d67
commit
9732bb533c
@@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def ExtractElementOp :
|
||||
def VectorExtractElementOp :
|
||||
Vector_Op<"extractelement", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
@@ -66,14 +66,17 @@ def ExtractElementOp :
|
||||
%2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value *source, ArrayRef<int32_t>">];
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPositionAttrName() { return "position"; }
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def OuterProductOp :
|
||||
def VectorOuterProductOp :
|
||||
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
|
||||
Results<(outs AnyVector)> {
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/Conversion/VectorConversions/VectorConversions.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Conversion/VectorConversions/VectorConversions.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
@@ -49,19 +49,19 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
|
||||
.getPointerTo();
|
||||
}
|
||||
|
||||
class ExtractElementOpConversion : public LLVMOpLowering {
|
||||
class VectorExtractElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit ExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::VectorExtractElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
auto extractOp = cast<vector::ExtractElementOp>(op);
|
||||
auto adaptor = vector::VectorExtractElementOpOperandAdaptor(operands);
|
||||
auto extractOp = cast<vector::VectorExtractElementOp>(op);
|
||||
auto vectorType = extractOp.vector()->getType().cast<VectorType>();
|
||||
auto resultType = extractOp.getResult()->getType();
|
||||
auto llvmResultType = lowering.convertType(resultType);
|
||||
@@ -103,25 +103,25 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class OuterProductOpConversion : public LLVMOpLowering {
|
||||
class VectorOuterProductOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit OuterProductOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
explicit VectorOuterProductOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::VectorOuterProductOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
|
||||
auto adaptor = vector::VectorOuterProductOpOperandAdaptor(operands);
|
||||
auto *ctx = op->getContext();
|
||||
auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
|
||||
auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
|
||||
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto llvmArrayOfVectType = lowering.convertType(
|
||||
cast<vector::OuterProductOp>(op).getResult()->getType());
|
||||
cast<vector::VectorOuterProductOp>(op).getResult()->getType());
|
||||
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
|
||||
Value *a = adaptor.lhs(), *b = adaptor.rhs();
|
||||
Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
|
||||
@@ -246,8 +246,8 @@ public:
|
||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
|
||||
VectorTypeCastOpConversion>(
|
||||
patterns.insert<VectorExtractElementOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
|
||||
converter.getDialect()->getContext(), converter);
|
||||
}
|
||||
|
||||
|
||||
@@ -44,17 +44,34 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
// VectorExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ExtractElementOp op) {
|
||||
static Type inferExtractOpResultType(VectorType vectorType,
|
||||
ArrayAttr position) {
|
||||
if (static_cast<int64_t>(position.size()) == vectorType.getRank())
|
||||
return vectorType.getElementType();
|
||||
return VectorType::get(vectorType.getShape().drop_front(position.size()),
|
||||
vectorType.getElementType());
|
||||
}
|
||||
|
||||
void VectorExtractElementOp::build(Builder *builder, OperationState &result,
|
||||
Value *source, ArrayRef<int32_t> position) {
|
||||
result.addOperands(source);
|
||||
auto positionAttr = builder->getI32ArrayAttr(position);
|
||||
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
|
||||
positionAttr));
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, VectorExtractElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
llvm::SMLoc attributeLoc, typeLoc;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
OpAsmParser::OperandType vector;
|
||||
@@ -77,19 +94,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
attributeLoc,
|
||||
"expected position attribute of rank smaller than vector");
|
||||
|
||||
Type resType =
|
||||
(static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
|
||||
? vectorType.getElementType()
|
||||
: VectorType::get(
|
||||
vectorType.getShape().drop_front(positionAttr.size()),
|
||||
vectorType.getElementType());
|
||||
|
||||
Type resType = inferExtractOpResultType(vectorType, positionAttr);
|
||||
result.attributes = attrs;
|
||||
return failure(parser.resolveOperand(vector, type, result.operands) ||
|
||||
parser.addTypeToList(resType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(ExtractElementOp op) {
|
||||
static LogicalResult verify(VectorExtractElementOp op) {
|
||||
auto positionAttr = op.position().getValue();
|
||||
if (positionAttr.empty())
|
||||
return op.emitOpError("expected non-empty position attribute");
|
||||
@@ -107,19 +118,20 @@ static LogicalResult verify(ExtractElementOp op) {
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OuterProductOp
|
||||
// VectorOuterProductOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, OuterProductOp op) {
|
||||
static void print(OpAsmPrinter &p, VectorOuterProductOp op) {
|
||||
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||
if (llvm::size(op.acc()) > 0)
|
||||
p << ", " << **op.acc().begin();
|
||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseOuterProductOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
static ParseResult parseVectorOuterProductOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
|
||||
Type tLHS, tRHS;
|
||||
if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
|
||||
@@ -142,7 +154,7 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
|
||||
parser.addTypeToList(resType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(OuterProductOp op) {
|
||||
static LogicalResult verify(VectorOuterProductOp op) {
|
||||
VectorType vLHS = op.getOperandVectorTypeLHS(),
|
||||
vRHS = op.getOperandVectorTypeRHS(),
|
||||
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
|
||||
|
||||
Reference in New Issue
Block a user