mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
[VectorOps] Add a BroadcastOp to the VectorOps dialect
PiperOrigin-RevId: 282643305
This commit is contained in:
committed by
A. Unique TensorFlower
parent
18aec3e2e5
commit
cf97263cb8
@@ -162,6 +162,34 @@ def Vector_ContractionOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_BroadcastOp :
|
||||
Vector_Op<"broadcast", [NoSideEffect,
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"dest operand and result have same type",
|
||||
TCresIsSameAsOpBase<0, 1>>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVector:$dest)>,
|
||||
Results<(outs AnyVector:$vector)> {
|
||||
let summary = "broadcast operation";
|
||||
let description = [{
|
||||
Broadcasts the scalar or k-D vector value in the source to the n-D
|
||||
destination vector of a proper shape such that the broadcast makes sense.
|
||||
|
||||
Examples:
|
||||
```
|
||||
%0 = constant 0.0 : f32
|
||||
%1 = vector.broadcast %0, %x : f32 into vector<16xf32>
|
||||
%2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
|
||||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Vector_ExtractElementOp :
|
||||
Vector_Op<"extractelement", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
|
||||
@@ -368,6 +368,47 @@ static LogicalResult verify(ExtractElementOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BroadcastOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
|
||||
p << " : " << op.getSourceType();
|
||||
p << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(BroadcastOp op) {
|
||||
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
|
||||
VectorType dstVectorType = op.getDestVectorType();
|
||||
// Scalar to vector broadcast is always valid. A vector
|
||||
// to vector broadcast needs some additional checking.
|
||||
if (srcVectorType) {
|
||||
const int64_t srcRank = srcVectorType.getRank();
|
||||
const int64_t dstRank = dstVectorType.getRank();
|
||||
// TODO(ajcbik): implement proper rank testing for broadcast;
|
||||
// this is just a temporary placeholder check.
|
||||
if (srcRank > dstRank) {
|
||||
return op.emitOpError("source rank higher than destination rank");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseBroadcastOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType source, dest;
|
||||
Type sourceType;
|
||||
VectorType destType;
|
||||
return failure(parser.parseOperand(source) || parser.parseComma() ||
|
||||
parser.parseOperand(dest) ||
|
||||
parser.parseColonType(sourceType) ||
|
||||
parser.parseKeywordType("into", destType) ||
|
||||
parser.resolveOperand(source, sourceType, result.operands) ||
|
||||
parser.resolveOperand(dest, destType, result.operands) ||
|
||||
parser.addTypeToList(destType, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) {
|
||||
// expected-error@+1 {{source rank higher than destination rank}}
|
||||
%2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_element_vector_type(%arg0: index) {
|
||||
// expected-error@+1 {{expected vector type}}
|
||||
%1 = vector.extractelement %arg0[] : index
|
||||
|
||||
@@ -22,6 +22,15 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_broadcast
|
||||
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) {
|
||||
// CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32>
|
||||
%0 = vector.broadcast %a, %b : f32 into vector<16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32>
|
||||
%1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extractelement
|
||||
func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
|
||||
Reference in New Issue
Block a user