mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][SVE] Add an e2e test for vector.contract (#69845)
Adds an end-to-end test for `vector.contract` that targets SVE (i.e.
scalable vectors). Note that this requires lifting the restriction on
`vector.outerproduct` (to which `vector.contract` is lowered to) that
would deem the following as invalid by the Op verifier (*):
```
vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32>
```
This is indeed valid as the end-to-end test demonstrates (at least when
compiling for SVE).
This commit is contained in:
committed by
GitHub
parent
b13e9efcef
commit
f24c443e82
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
|
||||
return emitOpError("expected #1 operand dim to match result dim #1");
|
||||
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
|
||||
return emitOpError("expected #2 operand dim to match result dim #2");
|
||||
if (vRHS.isScalable() != vLHS.isScalable())
|
||||
return emitOpError("expected either all or none of vector operands #1 "
|
||||
"and #2 to be scalable");
|
||||
if (vLHS.isScalable() && !vRHS.isScalable()) {
|
||||
// This restriction reflects what's currently supported in terms of
|
||||
// scalable vectors. However, we could relax this if there's a use case.
|
||||
return emitOpError(
|
||||
"expected either both or only #2 operand dim to be scalable");
|
||||
}
|
||||
} else {
|
||||
// An AXPY operation.
|
||||
if (vRES.getRank() != 1)
|
||||
|
||||
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_extract_contract4(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
|
||||
// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
|
||||
// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
|
||||
// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
|
||||
// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
|
||||
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
|
||||
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
|
||||
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
|
||||
|
||||
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
|
||||
%arg1: vector<5x7xf32>,
|
||||
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
|
||||
return %0 : vector<3x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
|
||||
// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
|
||||
// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
|
||||
// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
|
||||
// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
|
||||
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
|
||||
// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
|
||||
// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
|
||||
// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
|
||||
// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
|
||||
// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
|
||||
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
|
||||
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
|
||||
|
||||
// Note that only the J dimension is scalable in this example. In theory, all
|
||||
// dimensions could be be scalable, but there is no target yet for which this
|
||||
// would make sense.
|
||||
func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
|
||||
%arg1: vector<5x[7]xf32>,
|
||||
%arg2: vector<3x[7]xf32>,
|
||||
%m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
|
||||
%0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
|
||||
: vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
|
||||
return %0 : vector<3x[7]xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @matmul
|
||||
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
|
||||
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
|
||||
|
||||
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
|
||||
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
|
||||
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
|
||||
|
||||
// expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
|
||||
// expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
|
||||
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
|
||||
// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
|
||||
// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
|
||||
// DEFINE: %{entry} =
|
||||
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
|
||||
|
||||
// This check whether the files compiles and generates a temporary that will be executed further down.
|
||||
// RUN: %{compile}
|
||||
|
||||
// REDEFINE: %{entry} = matmul_i32
|
||||
// RUN: %{run} | FileCheck %s --check-prefix=I32
|
||||
|
||||
// REDEFINE: %{entry} = matmul_f32
|
||||
// RUN: %{run} | FileCheck %s --check-prefix=F32
|
||||
|
||||
// NOTE: These tests are meant to complement the integration tests from:
|
||||
// * ../test-contraction.mlir
|
||||
// (tests with fixed width vectors). Rather than duplicating those tests, this
|
||||
// file focuses on excercissing scalable vectors in a few most common cases.
|
||||
|
||||
// TODO: Masks + matvec + dot product
|
||||
|
||||
#matmat_accesses = [
|
||||
affine_map<(i, j, k) -> (i, k)>,
|
||||
affine_map<(i, j, k) -> (k, j)>,
|
||||
affine_map<(i, j, k) -> (i, j)>
|
||||
]
|
||||
#matmat_trait = {
|
||||
indexing_maps = #matmat_accesses,
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
|
||||
func.func @matmul_i32() {
|
||||
// Setup vector A:
|
||||
%vector_a = arith.constant dense<123> : vector<3x5xi32>
|
||||
|
||||
// Setup vector B:
|
||||
%vector_b = arith.constant dense<123> : vector<5x[2]xi32>
|
||||
|
||||
// Setup vector C:
|
||||
%vector_c = arith.constant dense<314> : vector<3x[2]xi32>
|
||||
|
||||
// Matmul
|
||||
%0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
|
||||
: vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32>
|
||||
|
||||
// Print the output
|
||||
%slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
|
||||
// I32: ( 75959, 75959
|
||||
vector.print %slice1 : vector<[2]xi32>
|
||||
%slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
|
||||
// I32-NEXT: ( 75959, 75959
|
||||
vector.print %slice2 : vector<[2]xi32>
|
||||
%slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
|
||||
// I32-NEXT: ( 75959, 75959
|
||||
vector.print %slice3 : vector<[2]xi32>
|
||||
|
||||
// CHECK: SVE: END OF TEST OUTPUT
|
||||
vector.print str "SVE: END OF TEST OUTPUT"
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func.func @matmul_f32() {
|
||||
// Setup vector A:
|
||||
%vector_a = arith.constant dense<1.23> : vector<3x5xf32>
|
||||
|
||||
// Setup vector B:
|
||||
%vector_b = arith.constant dense<1.23> : vector<5x[2]xf32>
|
||||
|
||||
// Setup vector C:
|
||||
%vector_c = arith.constant dense<3.14> : vector<3x[2]xf32>
|
||||
|
||||
// Matmul
|
||||
%0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
|
||||
: vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32>
|
||||
|
||||
// Print the output
|
||||
%slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
|
||||
// F32: ( 10.7045, 10.7045
|
||||
vector.print %slice1 : vector<[2]xf32>
|
||||
%slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
|
||||
// F32-NEXT: ( 10.7045, 10.7045
|
||||
vector.print %slice2 : vector<[2]xf32>
|
||||
%slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
|
||||
// F32-NEXT: ( 10.7045, 10.7045
|
||||
vector.print %slice3 : vector<[2]xf32>
|
||||
|
||||
// CHECK: SVE: END OF TEST OUTPUT
|
||||
vector.print str "SVE: END OF TEST OUTPUT"
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%module_op: !transform.any_op):
|
||||
%f = transform.structured.match ops{["func.func"]} in %module_op
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.apply_patterns to %f {
|
||||
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
|
||||
} : !transform.any_op
|
||||
}
|
||||
Reference in New Issue
Block a user