From 01145544aad4df7c5f5dcaff2631ef00a01800ce Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 25 Nov 2019 08:46:37 -0800 Subject: [PATCH] Add vector.insertelement op This is the counterpart of vector.extractelement op and has the same limitations at the moment (static I64IntegerArrayAttr to express position). This restriction will be filterd in the future. LLVM lowering will be added in a subsequent commit. PiperOrigin-RevId: 282365760 --- .../mlir/Dialect/VectorOps/VectorOps.td | 34 +++++++++ mlir/include/mlir/IR/OpBase.td | 6 ++ mlir/lib/Dialect/VectorOps/VectorOps.cpp | 74 ++++++++++++++++++- mlir/test/Dialect/VectorOps/invalid.mlir | 41 +++++++++- mlir/test/Dialect/VectorOps/ops.mlir | 11 +++ 5 files changed, 161 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index ce03750b9125..8526367a6d4b 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -169,6 +169,40 @@ def Vector_ExtractElementOp : }]; } +def Vector_InsertElementOp : + Vector_Op<"insertelement", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>, + Results<(outs AnyVector)> { + let summary = "insertelement operation"; + let description = [{ + Takes an n-D source vector, an (n+k)-D destination vector and a k-D position + and inserts the n-D source into the (n+k)-D destination at the proper + position. Degenerates to a scalar source type when n = 0. + + Examples: + ``` + %2 = vector.insertelement %0, %1[3 : i32]: + vector<8x16xf32> into vector<4x8x16xf32> + %5 = vector.insertelement %3, %4[3 : i32, 3 : i32, 3 : i32]: + f32 into vector<4x8x16xf32> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value *source, " # + "Value *dest, ArrayRef">]; + let extraClassDeclaration = [{ + static StringRef getPositionAttrName() { return "position"; } + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + def Vector_StridedSliceOp : Vector_Op<"strided_slice", [NoSideEffect, PredOpTrait<"operand and result have same element type", diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 32f880df3eec..136c836a95f8 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1668,6 +1668,12 @@ class TCOpResIsShapedTypePred : And<[ SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", IsShapedTypePred>]>; +// Predicate to verify that the i'th result and the j'th operand have the same +// type. +class TCresIsSameAsOpBase : + CPred<"$_op.getResult(" # i # ")->getType() == " + "$_op.getOperand(" # j # ")->getType()">; + // Basic Predicate to verify that the i'th result and the j'th operand have the // same elemental type. class TCresVTEtIsSameAsOpBase : diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 684616f76718..176475af4898 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -291,12 +291,84 @@ static LogicalResult verify(ExtractElementOp op) { attr.getInt() > op.getVectorType().getDimSize(en.index())) return op.emitOpError("expected position attribute #") << (en.index() + 1) - << " to be a positive integer smaller than the corresponding " + << " to be a non-negative integer smaller than the corresponding " "vector dimension"; } return success(); } +//===----------------------------------------------------------------------===// +// InsertElementOp +//===----------------------------------------------------------------------===// + +void InsertElementOp::build(Builder *builder, OperationState &result, + Value *source, Value *dest, + ArrayRef position) { + result.addOperands({source, dest}); + auto positionAttr = builder->getI32ArrayAttr(position); + result.addTypes(dest->getType()); + result.addAttribute(getPositionAttrName(), positionAttr); +} + +static void print(OpAsmPrinter &p, InsertElementOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() + << op.position(); + p.printOptionalAttrDict(op.getAttrs(), + {InsertElementOp::getPositionAttrName()}); + p << " : " << op.getSourceType(); + p << " into " << op.getDestVectorType(); +} + +static ParseResult parseInsertElementOp(OpAsmParser &parser, + OperationState &result) { + SmallVector attrs; + OpAsmParser::OperandType source, dest; + Type sourceType; + VectorType destType; + Attribute attr; + return failure(parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || + parser.parseAttribute(attr, + InsertElementOp::getPositionAttrName(), + result.attributes) || + parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(sourceType) || + parser.parseKeywordType("into", destType) || + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + +static LogicalResult verify(InsertElementOp op) { + auto positionAttr = op.position().getValue(); + if (positionAttr.empty()) + return op.emitOpError("expected non-empty position attribute"); + auto destVectorType = op.getDestVectorType(); + if (positionAttr.size() > static_cast(destVectorType.getRank())) + return op.emitOpError( + "expected position attribute of rank smaller than dest vector rank"); + auto srcVectorType = op.getSourceType().dyn_cast(); + if (srcVectorType && + (static_cast(srcVectorType.getRank()) + positionAttr.size() != + static_cast(destVectorType.getRank()))) + return op.emitOpError("expected position attribute rank + source rank to " + "match dest vector rank"); + else if (!srcVectorType && (positionAttr.size() != + static_cast(destVectorType.getRank()))) + return op.emitOpError( + "expected position attribute rank to match the dest vector rank"); + for (auto en : llvm::enumerate(positionAttr)) { + auto attr = en.value().dyn_cast(); + if (!attr || attr.getInt() < 0 || + attr.getInt() > destVectorType.getDimSize(en.index())) + return op.emitOpError("expected position attribute #") + << (en.index() + 1) + << " to be a non-negative integer smaller than the corresponding " + "dest vector dimension"; + } + return success(); +} + //===----------------------------------------------------------------------===// // StridedSliceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 3ec8b4eca557..f33e27b320b8 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -31,19 +31,54 @@ func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) { // ----- func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}} + // expected-error@+1 {{expected position attribute #2 to be a non-negative integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32> } // ----- func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}} + // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32> } // ----- +func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected non-empty position attribute}} + %1 = vector.insertelement %a, %b[] : f32 into vector<4x8x16xf32> +} + +// ----- + +func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}} + %1 = vector.insertelement %a, %b[3 : i32,3 : i32,3 : i32,3 : i32,3 : i32,3 : i32] : f32 into vector<4x8x16xf32> +} + +// ----- + +func @insert_element_vector_type(%a: vector<4xf32>, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}} + %1 = vector.insertelement %a, %b[3 : i32] : vector<4xf32> into vector<4x8x16xf32> +} + +// ----- + +func @insert_element_vector_type(%a: f32, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute rank to match the dest vector rank}} + %1 = vector.insertelement %a, %b[3 : i32,3 : i32] : f32 into vector<4x8x16xf32> +} + +// ----- + +func @insertelement_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}} + %1 = vector.insertelement %a, %b[0 : i32, 0 : i32, -1 : i32] : f32 into vector<4x8x16xf32> +} + +// ----- + func @outerproduct_num_operands(%arg0: f32) { // expected-error@+1 {{expected at least 2 operands}} %1 = vector.outerproduct %arg0 : f32, f32 @@ -369,5 +404,3 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> return } - - diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 40b00291a051..2cff590401cd 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -33,6 +33,17 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: insertelement +func @insertelement(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) { + // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> + %1 = vector.insertelement %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> + %2 = vector.insertelement %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insertelement %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> + %3 = vector.insertelement %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32> + return +} + // CHECK-LABEL: outerproduct func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>