[mlir] Allow vector.contract to have mixed types operands

Allow lhs and rhs to have different type than accumulator/destination. Some
hardware like GPUs support natively operations like uint8xuint8xuint32.

Differential Revision: https://reviews.llvm.org/D82069
This commit is contained in:
Thomas Raoux
2020-06-19 17:08:57 -07:00
parent c310bf8256
commit e4bc08f012
4 changed files with 19 additions and 7 deletions

View File

@@ -40,12 +40,9 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// with operators other than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect,
PredOpTrait<"first operand lhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand rhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
TCresVTEtIsSameAsOpBase<0, 2>>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
@@ -140,6 +137,11 @@ def Vector_ContractionOp :
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// Vector contraction with mixed typed. lhs/rhs have different element
// types than accumulator/result.
%6 = vector.contract #contraction_trait %0, %1, %2
: vector<10xf16>, vector<10xf16> into f32
```
}];
let builders = [OpBuilder<

View File

@@ -28,6 +28,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/CommandLine.h"
@@ -1731,6 +1732,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO(ajcbik): implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
// TODO(thomasraoux): support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
return failure();
// TODO(ntv, ajcbik): implement benefits, cost models.
MLIRContext *ctx = op.getContext();

View File

@@ -760,7 +760,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
func @contraction(%arg0: vector<4x3xi32>,
%arg1: vector<3x7xf32>,
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
// expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
// expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
}

View File

@@ -175,7 +175,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
// CHECK-LABEL: @contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
%arg4 : index) {
%arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
@@ -193,6 +193,10 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// Test contraction with mixed type.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
%3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
: vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
return
}