mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 00:20:25 +08:00
[mlir][Vector] Retire one old filter-based test
Differential Revision: https://reviews.llvm.org/D146742
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
|
||||
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
|
||||
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
|
||||
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
|
||||
|
||||
#dotp_accesses = [
|
||||
@@ -1182,32 +1181,6 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
|
||||
return %0 : vector<3x2xf32>
|
||||
}
|
||||
|
||||
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
|
||||
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
|
||||
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
|
||||
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
|
||||
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
|
||||
func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
|
||||
-> vector<4x4xf32>
|
||||
{
|
||||
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
|
||||
: vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
|
||||
return %0 : vector<4x4xf32>
|
||||
}
|
||||
|
||||
// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
|
||||
// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
|
||||
// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
|
||||
// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
|
||||
// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
|
||||
func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
|
||||
-> vector<3x4xf32>
|
||||
{
|
||||
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
|
||||
: vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
|
||||
return %0 : vector<3x4xf32>
|
||||
}
|
||||
|
||||
// PARALLEL-LABEL: func @parrallel_contract_lowering
|
||||
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
|
||||
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <type_traits>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
@@ -136,11 +136,6 @@ struct TestVectorContractionLowering
|
||||
*this, "vector-outerproduct",
|
||||
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> lowerToFilterOuterProduct{
|
||||
*this, "vector-filter-outerproduct",
|
||||
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
|
||||
"vectors of size 4."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> lowerToParallelArith{
|
||||
*this, "vector-parallel-arith",
|
||||
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
|
||||
@@ -159,22 +154,6 @@ struct TestVectorContractionLowering
|
||||
return;
|
||||
}
|
||||
|
||||
// Test on one pattern in isolation.
|
||||
if (lowerToFilterOuterProduct) {
|
||||
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
||||
VectorTransformsOptions options{lowering};
|
||||
patterns.add<ContractionOpToOuterProductOpLowering>(
|
||||
options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
|
||||
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
|
||||
// where M is not 4.
|
||||
if (op.getRhsType().getShape()[0] == 4)
|
||||
return failure();
|
||||
return success();
|
||||
});
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
return;
|
||||
}
|
||||
|
||||
if (lowerToParallelArith) {
|
||||
vector::populateVectorContractLoweringPatterns(
|
||||
patterns,
|
||||
|
||||
Reference in New Issue
Block a user