[mlir][Vector] Retire one old filter-based test

Differential Revision: https://reviews.llvm.org/D146742
This commit is contained in:
Nicolas Vasilache
2023-03-23 10:47:04 -07:00
parent 5c9a26238a
commit 73bec2b2c3
2 changed files with 1 additions and 49 deletions

View File

@@ -1,7 +1,6 @@
// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
#dotp_accesses = [
@@ -1182,32 +1181,6 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
-> vector<4x4xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
return %0 : vector<4x4xf32>
}
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
-> vector<3x4xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
return %0 : vector<3x4xf32>
}
// PARALLEL-LABEL: func @parrallel_contract_lowering
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>

View File

@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include <optional>
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -136,11 +136,6 @@ struct TestVectorContractionLowering
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
Option<bool> lowerToFilterOuterProduct{
*this, "vector-filter-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
"vectors of size 4."),
llvm::cl::init(false)};
Option<bool> lowerToParallelArith{
*this, "vector-parallel-arith",
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
@@ -159,22 +154,6 @@ struct TestVectorContractionLowering
return;
}
// Test on one pattern in isolation.
if (lowerToFilterOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.add<ContractionOpToOuterProductOpLowering>(
options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
// where M is not 4.
if (op.getRhsType().getShape()[0] == 4)
return failure();
return success();
});
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
if (lowerToParallelArith) {
vector::populateVectorContractLoweringPatterns(
patterns,