mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 14:50:42 +08:00
[VectorOps] Add [insert/extract]element definition together with lowering to LLVM
Similar to insert/extract vector instructions but (1) work on 1-D vectors only (2) allow for a dynamic index %c3 = constant 3 : index %0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4xf32> %1 = vector.extractelement %arg0[%c3 : index] : vector<4xf32> PiperOrigin-RevId: 285792205
This commit is contained in:
committed by
A. Unique TensorFlower
parent
73ec37c8bb
commit
cd5dab8ad7
@@ -267,6 +267,33 @@ def Vector_ShuffleOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_ExtractElementOp :
|
||||
Vector_Op<"extractelement", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
Arguments<(ins AnyVector:$vector, Index:$position)>,
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "extractelement operation";
|
||||
let description = [{
|
||||
Takes an 1-D vector and a dynamic index position and extracts the
|
||||
scalar at that position. Note that this instruction resembles
|
||||
vector.extract, but is restricted to 1-D vectors and relaxed
|
||||
to dynamic indices. It is meant to be closer to LLVM's version:
|
||||
https://llvm.org/docs/LangRef.html#extractelement-instruction
|
||||
|
||||
Example:
|
||||
```
|
||||
%c = constant 15 : i32
|
||||
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_ExtractOp :
|
||||
Vector_Op<"extract", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
@@ -346,6 +373,38 @@ def Vector_ExtractSlicesOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertElementOp :
|
||||
Vector_Op<"insertelement", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest, Index:$position)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "insertelement operation";
|
||||
let description = [{
|
||||
Takes a scalar source, an 1-D destination vector and a dynamic index
|
||||
position and inserts the source into the destination at the proper
|
||||
position. Note that this instruction resembles vector.insert, but
|
||||
is restricted to 1-D vectors and relaxed to dynamic indices. It is
|
||||
meant to be closer to LLVM's version:
|
||||
https://llvm.org/docs/LangRef.html#insertelement-instruction
|
||||
|
||||
Example:
|
||||
```
|
||||
%c = constant 15 : i32
|
||||
%f = constant 0.0f : f32
|
||||
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_InsertOp :
|
||||
Vector_Op<"insert", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
|
||||
@@ -300,6 +300,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
auto extractEltOp = cast<vector::ExtractElementOp>(op);
|
||||
auto vectorType = extractEltOp.getVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType.getElementType());
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, llvmType, adaptor.vector(), adaptor.position());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorExtractOpConversion(MLIRContext *context,
|
||||
@@ -355,6 +380,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorInsertElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
||||
auto insertEltOp = cast<vector::InsertElementOp>(op);
|
||||
auto vectorType = insertEltOp.getDestVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType);
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorInsertOpConversion(MLIRContext *context,
|
||||
@@ -566,7 +616,8 @@ public:
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractOpConversion, VectorInsertOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
|
||||
converter.getDialect()->getContext(), converter);
|
||||
}
|
||||
|
||||
@@ -346,6 +346,42 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
||||
return res;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << "[" << *op.position()
|
||||
<< " : " << op.position()->getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType vector, position;
|
||||
Type positionType;
|
||||
VectorType vectorType;
|
||||
if (parser.parseOperand(vector) || parser.parseLSquare() ||
|
||||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
|
||||
parser.parseRSquare() ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(vectorType))
|
||||
return failure();
|
||||
Type resultType = vectorType.getElementType();
|
||||
return failure(
|
||||
parser.resolveOperand(vector, vectorType, result.operands) ||
|
||||
parser.resolveOperand(position, positionType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(vector::ExtractElementOp op) {
|
||||
VectorType vectorType = op.getVectorType();
|
||||
if (vectorType.getRank() != 1)
|
||||
return op.emitOpError("expected 1-D vector");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -684,6 +720,44 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "["
|
||||
<< *op.position() << " : " << op.position()->getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.dest()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType source, dest, position;
|
||||
Type positionType;
|
||||
VectorType destType;
|
||||
if (parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) || parser.parseLSquare() ||
|
||||
parser.parseOperand(position) || parser.parseColonType(positionType) ||
|
||||
parser.parseRSquare() ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(destType))
|
||||
return failure();
|
||||
Type sourceType = destType.getElementType();
|
||||
return failure(
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.resolveOperand(dest, destType, result.operands) ||
|
||||
parser.resolveOperand(position, positionType, result.operands) ||
|
||||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
auto dstVectorType = op.getDestVectorType();
|
||||
if (dstVectorType.getRank() != 1)
|
||||
return op.emitOpError("expected 1-D vector");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -280,6 +280,16 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
|
||||
// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e3]], %[[i2]][2] : !llvm<"[3 x <4 x float>]">
|
||||
// CHECK: llvm.return %[[i3]] : !llvm<"[3 x <4 x float>]">
|
||||
|
||||
func @extract_element(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = constant 15 : index
|
||||
%1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
|
||||
return %1 : f32
|
||||
}
|
||||
// CHECK-LABEL: extract_element(%arg0: !llvm<"<16 x float>">)
|
||||
// CHECK: %[[c:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
|
||||
// CHECK: %[[x:.*]] = llvm.extractelement %arg0[%[[c]] : !llvm.i64] : !llvm<"<16 x float>">
|
||||
// CHECK: llvm.return %[[x]] : !llvm.float
|
||||
|
||||
func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
|
||||
%0 = vector.extract %arg0[15 : i32]: vector<16xf32>
|
||||
return %0 : f32
|
||||
@@ -315,6 +325,16 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
|
||||
// CHECK: llvm.return {{.*}} : !llvm.float
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
|
||||
%0 = constant 3 : index
|
||||
%1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32>
|
||||
return %1 : vector<4xf32>
|
||||
}
|
||||
// CHECK-LABEL: insert_element(%arg0: !llvm.float, %arg1: !llvm<"<4 x float>">)
|
||||
// CHECK: %[[c:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: %[[x:.*]] = llvm.insertelement %arg0, %arg1[%[[c]] : !llvm.i64] : !llvm<"<4 x float>">
|
||||
// CHECK: llvm.return %[[x]] : !llvm<"<4 x float>">
|
||||
|
||||
func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
|
||||
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
|
||||
@@ -60,12 +60,20 @@ func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
|
||||
// -----
|
||||
|
||||
func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
|
||||
// expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}}
|
||||
// expected-error@+1 {{'vector.shuffle' invalid mask length}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element(%arg0: vector<4x4xf32>) {
|
||||
%c = constant 3 : index
|
||||
// expected-error@+1 {{'vector.extractelement' op expected 1-D vector}}
|
||||
%1 = vector.extractelement %arg0[%c : index] : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_vector_type(%arg0: index) {
|
||||
// expected-error@+1 {{expected vector type}}
|
||||
%1 = vector.extract %arg0[] : index
|
||||
@@ -115,6 +123,22 @@ func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
|
||||
%c = constant 3 : index
|
||||
// expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
|
||||
%0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) {
|
||||
%c = constant 3 : index
|
||||
// expected-error@+1 {{'vector.insertelement' op failed to verify that source operand and result have same element type}}
|
||||
%0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, index) -> (vector<4xf32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected non-empty position attribute}}
|
||||
%1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
|
||||
|
||||
@@ -53,6 +53,15 @@ func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
|
||||
return %1 : vector<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract_element
|
||||
func @extract_element(%a: vector<16xf32>) -> f32 {
|
||||
// CHECK: %[[C15:.*]] = constant 15 : index
|
||||
%c = constant 15 : index
|
||||
// CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : index] : vector<16xf32>
|
||||
%1 = vector.extractelement %a[%c : index] : vector<16xf32>
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract
|
||||
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
@@ -64,6 +73,15 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f
|
||||
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_element
|
||||
func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
|
||||
// CHECK: %[[C15:.*]] = constant 15 : index
|
||||
%c = constant 15 : index
|
||||
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : index] : vector<16xf32>
|
||||
%1 = vector.insertelement %a, %b[%c : index] : vector<16xf32>
|
||||
return %1 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert
|
||||
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
|
||||
Reference in New Issue
Block a user