[MLIR][AArch64] Lower vector.contract with mixed signed/unsigned arguments to Neon FEAT_I8MM (#144698)
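In short: a `vector.contract` over i32 values whose operands are sign- and/or zero-extended from i8 (or narrower) can now be rewritten to `arm_neon.intr.smmla`, `arm_neon.intr.ummla`, or `arm_neon.intr.usmmla`. A condensed sketch of the signed-by-unsigned case, distilled from the tests added in this patch (the function name is illustrative; shapes and maps are as in those tests):

  func.func @signed_by_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>,
                                %acc: vector<2x2xi32>) -> vector<2x2xi32> {
    %l = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
    %r = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
    %res = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
      %l, %r, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
    return %res : vector<2x2xi32>
  }

After the rewrite this becomes a transpose of the accumulator, a single `arm_neon.intr.usmmla` with LHS and RHS swapped, and a transpose of the result (see the `vector_arm_neon_signed_unsigned` test below).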
@@ -13,7 +13,7 @@ namespace mlir {
 class RewritePatternSet;

 namespace arm_neon {
-void populateLowerContractionToSMMLAPatternPatterns(
+void populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns);
 } // namespace arm_neon

@@ -85,7 +85,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   populateVectorGatherLoweringPatterns(patterns);
   if (armI8MM) {
     if (armNeon)
-      arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
     if (armSVE)
       populateLowerContractionToSVEI8MMPatternPatterns(patterns);
   }

@@ -20,7 +20,7 @@ using namespace mlir;

 void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+  arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
 }

 //===----------------------------------------------------------------------===//

@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerContractionToSMMLAPattern.cpp
+  LowerContractionToNeonI8MMPattern.cpp

   DEPENDS
   MLIRArmNeonIncGen

@@ -1,4 +1,4 @@
-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering patterns from vector.contract to
-// arm_neon.intr.smmla
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM extension.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
 //
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -37,12 +42,87 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }

+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see the `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero-extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
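+
+// For example, given `%x : vector<2x8xi8>`, getExtOperand<arith::ExtSIOp>
+// returns `%x` both when the contraction consumes
+// `%e = arith.extsi %x : vector<2x8xi8> to vector<2x8xi32>` (explicit) and
+// when it consumes `%x` directly (implicit sign-extension); the
+// zero-extension instantiation accepts only an explicit `arith.extui`.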
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla
+  Unsigned,    // ummla
+  Mixed,       // usmmla
+  MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Unsigned:
+    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+                                                    rhs);
+  case MMLA::Mixed:
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+                                                     rhs);
+  case MMLA::MixedSwapped:
+    // The accumulator comes transposed and the result will be transposed
+    // later, so all we have to do here is swap the operands.
+    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+                                                     lhs);
+  }
+}
+
 /// Lowering from a vector::contractOp to an arm neon smmla intrinsic. This
 /// will tile any vector.contract into multiple smmla instructions with
 /// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
 /// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
 /// If no unrolling is necessary, a single smmla instruction is emitted.
-class LowerContractionToSMMLAPattern
+class LowerContractionToNeonI8MMPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -88,39 +168,64 @@ public:
       return failure();
     }

-    // Check two extsi inputs Rhs Lhs for contract.
-    arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
-    arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
-    if (!origLhsExtOp || !origRhsExtOp) {
-      return failure();
-    }
+    // Check that the inputs are sign-/zero-extensions from iN (N <= 8) to i32
+    // and get the values before the extension. All four signed/unsigned
+    // combinations of input operands are supported, but they lower to
+    // different operations. Determine the appropriate operation to lower to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+    }
+    if (!maybeRhs)
+      return failure();
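+
+    // At this point `mmlaOp` encodes the operand signedness:
+    //   extsi LHS, extsi RHS -> MMLA::Signed       (smmla)
+    //   extui LHS, extui RHS -> MMLA::Unsigned     (ummla)
+    //   extui LHS, extsi RHS -> MMLA::Mixed        (usmmla)
+    //   extsi LHS, extui RHS -> MMLA::MixedSwapped (usmmla, operands swapped)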
+    Value origLhs = *maybeLhs;
+    Value origRhs = *maybeRhs;

     // Match any iN (N <= 8) input and re-extend it to i8, to feed into the
     // Neon instruction that follows.
-    Value extsiLhs;
-    Value extsiRhs;
-    if (auto lhsExtInType =
-            dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
+    Value extLhs;
+    Value extRhs;
+    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
       if (lhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetLhsExtTy =
             matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhsExtOp.getIn());
+        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::MixedSwapped)
+          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
+        else
+          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
+                                                         origLhs);
       }
     }
-    if (auto rhsExtInType =
-            dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
+    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
       if (rhsExtInType.getElementTypeBitWidth() <= 8) {
         Type targetRhsExtTy =
             matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhsExtOp.getIn());
+        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::MixedSwapped)
+          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
+        else
+          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
+                                                         origRhs);
       }
     }

-    if (!extsiLhs || !extsiRhs) {
+    if (!extLhs || !extRhs) {
       return failure();
     }

@@ -155,11 +260,11 @@ public:
       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +296,13 @@ public:
         tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
       }

+      // Transpose ACC if doing signed by unsigned multiplication, because
+      // we're using the instruction for unsigned by signed multiplication
+      // with reversed operands.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType =
           VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +323,21 @@ public:
       }

       // Insert contract op
-      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
-          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-          collapsedRhs);
+      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
+                        collapsedRes, collapsedLhs, collapsedRhs);

       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           kAcc.getLoc(), tiledAcc.getType(), kAcc);

-      // With vecmat, only one row of tiled ACC can be inserted into file result
+      // Because of the reversed operands, the result is obtained transposed.
+      // Transpose it back.
+      if (mmlaOp == MMLA::MixedSwapped)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result.
       if (isVecmat) {
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
       }
@@ -239,8 +357,8 @@ public:

 } // namespace

-void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
+  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
@@ -1,4 +1,4 @@
-//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,6 +9,11 @@
 // This file implements lowering patterns from vector.contract to operations
 // that map to instructions from the SVE FEAT_I8MM extension.
 //
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToNeonI8MMPattern.cpp
+//
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/IR/Arith.h"

@@ -17,14 +17,28 @@ func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4

 // -----

-// CHECK-LABEL: vector_arm_neon_same_types
-// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
-// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
-// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
-func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK-LABEL: vector_arm_neon_implicit_extsi
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_implicit_extsi(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi8>, vector<2x8xi8> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_signed
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -33,11 +47,51 @@ func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>

 // -----

-// CHECK-LABEL: vector_arm_neon_without_extsi
-// CHECK-SAME: %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.contract
-func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
-  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+// CHECK-LABEL: vector_arm_neon_unsigned_signed
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }

 // -----

+// CHECK-LABEL: vector_arm_neon_unsigned_unsigned
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.ummla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_unsigned
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[ACC_T:.+]] = vector.transpose %[[ACC]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC_T]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[R]], %[[L]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[OUT_T:.+]] = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %{{.+}} = vector.transpose %[[OUT_T]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+  %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+  return %res : vector<2x2xi32>
+}
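
The pattern also claims support for vecmat shapes (dimM = 1, cf. the class comment in LowerContractionToNeonI8MMPattern.cpp). A minimal sketch of such an input, assuming the same lowering path applies (this function is illustrative and is not one of the tests in this patch):

  func.func @vecmat_example(%lhs: vector<8xi8>, %rhs: vector<2x8xi8>,
                            %acc: vector<2xi32>) -> vector<2xi32> {
    %l = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
    %r = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
    %res = vector.contract {
        indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0)>],
        iterator_types = ["parallel", "reduction"],
        kind = #vector.kind<add>}
      %l, %r, %acc : vector<8xi32>, vector<2x8xi32> into vector<2xi32>
    return %res : vector<2xi32>
  }

Here the LHS carries only the reduction dimension, so dimM = 1 is inferred and a single row of the tiled accumulator is extracted into the final result.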