mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir] Allow vector.contract to have mixed types operands
Allow lhs and rhs to have different type than accumulator/destination. Some hardware like GPUs support natively operations like uint8xuint8xuint32. Differential Revision: https://reviews.llvm.org/D82069
This commit is contained in:
@@ -40,12 +40,9 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
// with operators other than the current set: {*, +}.
|
||||
def Vector_ContractionOp :
|
||||
Vector_Op<"contract", [NoSideEffect,
|
||||
PredOpTrait<"first operand lhs and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"second operand rhs and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>,
|
||||
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
|
||||
PredOpTrait<"third operand acc and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>]>,
|
||||
TCresVTEtIsSameAsOpBase<0, 2>>]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
||||
Variadic<VectorOf<[I1]>>:$masks,
|
||||
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||
@@ -140,6 +137,11 @@ def Vector_ContractionOp :
|
||||
|
||||
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
|
||||
// Vector contraction with mixed typed. lhs/rhs have different element
|
||||
// types than accumulator/result.
|
||||
%6 = vector.contract #contraction_trait %0, %1, %2
|
||||
: vector<10xf16>, vector<10xf16> into f32
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
@@ -1731,6 +1732,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
||||
// TODO(ajcbik): implement masks.
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return failure();
|
||||
// TODO(thomasraoux): support mixed mode contract lowering.
|
||||
if (op.getLhsType().getElementType() !=
|
||||
getElementTypeOrSelf(op.getAccType()) ||
|
||||
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
|
||||
return failure();
|
||||
|
||||
// TODO(ntv, ajcbik): implement benefits, cost models.
|
||||
MLIRContext *ctx = op.getContext();
|
||||
|
||||
@@ -760,7 +760,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
func @contraction(%arg0: vector<4x3xi32>,
|
||||
%arg1: vector<3x7xf32>,
|
||||
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
|
||||
// expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
|
||||
// expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}}
|
||||
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
|
||||
// CHECK-LABEL: @contraction
|
||||
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
|
||||
%arg4 : index) {
|
||||
%arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
|
||||
// Test contraction with batch and contracting dims.
|
||||
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
|
||||
@@ -193,6 +193,10 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
|
||||
%rhs_mask
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
|
||||
// Test contraction with mixed type.
|
||||
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
|
||||
%3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
|
||||
: vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user