From cf97263cb8cd4f6f21f00eabb9d6a007e221eaab Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 26 Nov 2019 14:43:03 -0800 Subject: [PATCH] [VectorOps] Add a BroadcastOp to the VectorOps dialect PiperOrigin-RevId: 282643305 --- .../mlir/Dialect/VectorOps/VectorOps.td | 28 +++++++++++++ mlir/lib/Dialect/VectorOps/VectorOps.cpp | 41 +++++++++++++++++++ mlir/test/Dialect/VectorOps/invalid.mlir | 7 ++++ mlir/test/Dialect/VectorOps/ops.mlir | 9 ++++ 4 files changed, 85 insertions(+) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index a887a3e4e792..34c2fa97e536 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -162,6 +162,34 @@ def Vector_ContractionOp : }]; } +def Vector_BroadcastOp : + Vector_Op<"broadcast", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"dest operand and result have same type", + TCresIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$source, AnyVector:$dest)>, + Results<(outs AnyVector:$vector)> { + let summary = "broadcast operation"; + let description = [{ + Broadcasts the scalar or k-D vector value in the source to the n-D + destination vector of a proper shape such that the broadcast makes sense. + + Examples: + ``` + %0 = constant 0.0 : f32 + %1 = vector.broadcast %0, %x : f32 into vector<16xf32> + %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32> + ``` + }]; + let extraClassDeclaration = [{ + Type getSourceType() { return source()->getType(); } + VectorType getDestVectorType() { + return dest()->getType().cast(); + } + }]; +} + def Vector_ExtractElementOp : Vector_Op<"extractelement", [NoSideEffect, PredOpTrait<"operand and result have same element type", diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index b73b771d80dc..d09fd0fc2f2d 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -368,6 +368,47 @@ static LogicalResult verify(ExtractElementOp op) { return success(); } +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, BroadcastOp op) { + p << op.getOperationName() << " " << *op.source() << ", " << *op.dest(); + p << " : " << op.getSourceType(); + p << " into " << op.getDestVectorType(); +} + +static LogicalResult verify(BroadcastOp op) { + VectorType srcVectorType = op.getSourceType().dyn_cast(); + VectorType dstVectorType = op.getDestVectorType(); + // Scalar to vector broadcast is always valid. A vector + // to vector broadcast needs some additional checking. + if (srcVectorType) { + const int64_t srcRank = srcVectorType.getRank(); + const int64_t dstRank = dstVectorType.getRank(); + // TODO(ajcbik): implement proper rank testing for broadcast; + // this is just a temporary placeholder check. + if (srcRank > dstRank) { + return op.emitOpError("source rank higher than destination rank"); + } + } + return success(); +} + +static ParseResult parseBroadcastOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType source, dest; + Type sourceType; + VectorType destType; + return failure(parser.parseOperand(source) || parser.parseComma() || + parser.parseOperand(dest) || + parser.parseColonType(sourceType) || + parser.parseKeywordType("into", destType) || + parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands) || + parser.addTypeToList(destType, result.types)); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 60d57740e696..92e956ef29a1 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -2,6 +2,13 @@ // ----- +func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) { + // expected-error@+1 {{source rank higher than destination rank}} + %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32> +} + +// ----- + func @extract_element_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} %1 = vector.extractelement %arg0[] : index diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index a2b1ac34142c..51dbc4f04359 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -22,6 +22,15 @@ func @vector_transfer_ops(%arg0: memref) { return } +// CHECK-LABEL: @vector_broadcast +func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) { + // CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32> + %0 = vector.broadcast %a, %b : f32 into vector<16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32> + %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32> + return +} + // CHECK-LABEL: @extractelement func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) { // CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>