[mlir][RFC] Add scalable dimensions to VectorType

With VectorType supporting scalable dimensions, we no longer need many of
the operations currently present in ArmSVE, like mask generation and
basic arithmetic instructions. Therefore, this patch also removes them.

Having built-in scalable vector support also simplifies the lowering of
scalable vector dialects down to LLVM IR.

Scalable dimensions are indicated by enclosing them in square brackets:

        vector<[4]xf32>

is a scalable vector of 4 single-precision floating-point elements.

More generally, a VectorType can have a set of fixed-length dimensions
followed by a set of scalable dimensions:

        vector<2x[4x4]xf32>

is a vector of 2 scalable 4x4 vectors of single-precision floating-point
elements.

The scale of the scalable dimensions can be obtained with the
vector.vscale operation:

        %vs = vector.vscale
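
For example, the runtime element count of a vector<[4]xf32> is 4 times
vscale. A minimal sketch, using illustrative value names:

        %vs = vector.vscale
        %c4 = arith.constant 4 : index
        // Runtime element count of a vector<[4]xf32>: 4 * vscale.
        %n = arith.muli %c4, %vs : index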

This change is discussed in the Discourse RFC:

https://llvm.discourse.group/t/rfc-add-built-in-support-for-scalable-vector-types/4484

Differential Revision: https://reviews.llvm.org/D111819
This commit is contained in:
Javier Setoain
2021-10-12 14:26:01 +01:00
parent 7161aa06ef
commit a4830d14ed
33 changed files with 965 additions and 1113 deletions

View File

@@ -16,7 +16,6 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td"
//===----------------------------------------------------------------------===//
// ArmSVE dialect definition
@@ -27,69 +26,11 @@ def ArmSVE_Dialect : Dialect {
let cppNamespace = "::mlir::arm_sve";
let summary = "Basic dialect to target Arm SVE architectures";
let description = [{
This dialect contains the definitions necessary to target Arm SVE scalable
vector operations, including a scalable vector type and intrinsics for
some Arm SVE instructions.
}];
let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
// ArmSVE type definitions
//===----------------------------------------------------------------------===//
def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
CPred<"$_self.isa<ScalableVectorType>()">,
"scalable vector type">,
BuildableType<"$_builder.getType<ScalableVectorType>()"> {
let description = [{
`arm_sve.vector` represents vectors that will be processed by a scalable
vector architecture.
This dialect contains the definitions necessary to target specific Arm SVE
scalable vector operations.
}];
}
class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
let mnemonic = "vector";
let summary = "Scalable vector type";
let description = [{
A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
vectors, whose size is constant and known at compile time, scalable
vectors' length is constant but determined by the specific hardware at
run time.
}];
let parameters = (ins
ArrayRefParameter<"int64_t", "Vector shape">:$shape,
"Type":$elementType
);
let extraClassDeclaration = [{
bool hasStaticShape() const {
return llvm::none_of(getShape(), ShapedType::isDynamic);
}
int64_t getNumElements() const {
assert(hasStaticShape() &&
"cannot get element count of dynamic shaped type");
ArrayRef<int64_t> shape = getShape();
int64_t num = 1;
for (auto dim : shape)
num *= dim;
return num;
}
}];
}
//===----------------------------------------------------------------------===//
// Additional LLVM type constraints
//===----------------------------------------------------------------------===//
def LLVMScalableVectorType :
Type<CPred<"$_self.isa<::mlir::LLVM::LLVMScalableVectorType>()">,
"LLVM dialect scalable vector type">;
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -97,16 +38,6 @@ def LLVMScalableVectorType :
class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ArmSVE_Dialect, mnemonic, traits> {}
class ArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
list<OpTrait> traits =[]> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/mnemonic,
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<OpTrait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
@@ -117,42 +48,6 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
class ScalableFOp<string mnemonic, string op_description,
list<OpTrait> traits = []> :
ArmSVE_Op<mnemonic, !listconcat(traits,
[AllTypesMatch<["src1", "src2", "dst"]>])> {
let summary = op_description # " for scalable vectors of floats";
let description = [{
The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
returns one scalable vector with the result of the }] # op_description # [{.
}];
let arguments = (ins
ScalableVectorOf<[AnyFloat]>:$src1,
ScalableVectorOf<[AnyFloat]>:$src2
);
let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
let assemblyFormat =
"$src1 `,` $src2 attr-dict `:` type($src1)";
}
class ScalableIOp<string mnemonic, string op_description,
list<OpTrait> traits = []> :
ArmSVE_Op<mnemonic, !listconcat(traits,
[AllTypesMatch<["src1", "src2", "dst"]>])> {
let summary = op_description # " for scalable vectors of integers";
let description = [{
The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
returns one scalable vector with the result of the }] # op_description # [{.
}];
let arguments = (ins
ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
ScalableVectorOf<[I8, I16, I32, I64]>:$src2
);
let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
let assemblyFormat =
"$src1 `,` $src2 attr-dict `:` type($src1)";
}
class ScalableMaskedFOp<string mnemonic, string op_description,
list<OpTrait> traits = []> :
ArmSVE_Op<mnemonic, !listconcat(traits,
@@ -325,74 +220,6 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def VectorScaleOp : ArmSVE_Op<"vector_scale",
[NoSideEffect]> {
let summary = "Load vector scale size";
let description = [{
The vector_scale op returns the scale of the scalable vectors, a positive
integer value that is constant at runtime but unknown at compile time.
The scale of the vector indicates the multiplicity of the vectors and
vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
vector_scale consecutive vector<4xi32>; and an operation on an
!arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
times, once on each <4xi32> segment of the scalable vector. The vector_scale
op can be used to calculate the step in vector-length agnostic (VLA) loops.
}];
let results = (outs Index:$res);
let assemblyFormat =
"attr-dict `:` type($res)";
}
def ScalableLoadOp : ArmSVE_Op<"load">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base, Index:$index)>,
Results<(outs ScalableVectorOf<[AnyType]>:$result)> {
let summary = "Load scalable vector from memory";
let description = [{
Load a slice of memory into a scalable vector.
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
}];
let assemblyFormat = "$base `[` $index `]` attr-dict `:` "
"type($result) `from` type($base)";
}
def ScalableStoreOp : ArmSVE_Op<"store">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base, Index:$index,
ScalableVectorOf<[AnyType]>:$value)> {
let summary = "Store scalable vector into memory";
let description = [{
Store a scalable vector on a slice of memory.
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
}];
let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` "
"type($value) `to` type($base)";
}
def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
def ScalableDivFOp : ScalableFOp<"divf", "division">;
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
@@ -417,189 +244,56 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
//===----------------------------------------------------------------------===//
// ScalableCmpFOp
//===----------------------------------------------------------------------===//
def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands,
TypesMatchWith<"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
let summary = "floating-point comparison operation for scalable vectors";
let description = [{
The `arm_sve.cmpf` operation compares two scalable vectors of floating point
elements according to the float comparison rules and the predicate specified
by the respective attribute. The predicate defines the type of comparison:
(un)orderedness, (in)equality and signed less/greater than (or equal to) as
well as predicates that are always true or false. The result is a scalable
vector of i1 elements. Unlike `arm_sve.cmpi`, the operands are always
treated as signed. The u prefix indicates *unordered* comparison, not
unsigned comparison, so "une" means unordered not equal. For the sake of
readability by humans, custom assembly form for the operation uses a
string-typed attribute for the predicate. The value of this attribute
corresponds to lower-cased name of the predicate constant, e.g., "one" means
"ordered not equal". The string representation of the attribute is merely a
syntactic sugar and is converted to an integer attribute by the parser.
Example:
```mlir
%r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32>
```
}];
let arguments = (ins
Arith_CmpFPredicateAttr:$predicate,
ScalableVectorOf<[AnyFloat]>:$lhs,
ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar
);
let results = (outs ScalableVectorOf<[I1]>:$result);
let builders = [
OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs,
"Value":$rhs), [{
buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs);
}]>];
let extraClassDeclaration = [{
static StringRef getPredicateAttrName() { return "predicate"; }
static arith::CmpFPredicate getPredicateByName(StringRef name);
arith::CmpFPredicate getPredicate() {
return (arith::CmpFPredicate) (*this)->getAttrOfType<IntegerAttr>(
getPredicateAttrName()).getInt();
}
}];
let verifier = [{ return success(); }];
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
//===----------------------------------------------------------------------===//
// ScalableCmpIOp
//===----------------------------------------------------------------------===//
def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands,
TypesMatchWith<"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">]> {
let summary = "integer comparison operation for scalable vectors";
let description = [{
The `arm_sve.cmpi` operation compares two scalable vectors of integer
elements according to the predicate specified by the respective attribute.
The predicate defines the type of comparison:
- equal (mnemonic: `"eq"`; integer value: `0`)
- not equal (mnemonic: `"ne"`; integer value: `1`)
- signed less than (mnemonic: `"slt"`; integer value: `2`)
- signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
- signed greater than (mnemonic: `"sgt"`; integer value: `4`)
- signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
- unsigned less than (mnemonic: `"ult"`; integer value: `6`)
- unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
- unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
- unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
Example:
```mlir
%r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32>
```
}];
let arguments = (ins
Arith_CmpIPredicateAttr:$predicate,
ScalableVectorOf<[I8, I16, I32, I64]>:$lhs,
ScalableVectorOf<[I8, I16, I32, I64]>:$rhs
);
let results = (outs ScalableVectorOf<[I1]>:$result);
let builders = [
OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs,
"Value":$rhs), [{
buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs);
}]>];
let extraClassDeclaration = [{
static StringRef getPredicateAttrName() { return "predicate"; }
static arith::CmpIPredicate getPredicateByName(StringRef name);
arith::CmpIPredicate getPredicate() {
return (arith::CmpIPredicate) (*this)->getAttrOfType<IntegerAttr>(
getPredicateAttrName()).getInt();
}
}];
let verifier = [{ return success(); }];
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def UdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udot">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedAddIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"add">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedAddFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedMulIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"mul">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedMulFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedSubIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sub">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedSubFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedSDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedUDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
LLVMScalableVectorType)>;
def VectorScaleIntrOp:
ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
#endif // ARMSVE_OPS
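
With the dedicated type and the basic arithmetic ops removed, standard
arith ops apply directly to builtin scalable vectors. A before/after
sketch (the commented line is the pre-patch ArmSVE form):

```mlir
// Pre-patch: %0 = arm_sve.addi %a, %b : !arm_sve.vector<4xi32>
// Post-patch, using the builtin scalable vector type:
%0 = arith.addi %a, %b : vector<[4]xi32>
```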

View File

@@ -21,9 +21,6 @@
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"

View File

@@ -1,53 +0,0 @@
//===-- ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base operation definition file for ArmSVE scalable vector types.
//
//===----------------------------------------------------------------------===//
#ifndef ARMSVE_OP_BASE
#define ARMSVE_OP_BASE
//===----------------------------------------------------------------------===//
// ArmSVE scalable vector type constraints
//===----------------------------------------------------------------------===//
def IsScalableVectorTypePred :
CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
class ScalableVectorOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
"$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
"scalable vector">;
// Whether the number of elements of a scalable vector is from the given
// `allowedLengths` list
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedLengths, CPred<
[{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
# allowedlength>)>]>;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
IsScalableVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/"),
"::mlir::arm_sve::ScalableVectorType">;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[ScalableVectorOf<allowedTypes>.predicate,
ScalableVectorOfLength<allowedLengths>.predicate]>,
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::arm_sve::ScalableVectorType">;
#endif // ARMSVE_OP_BASE

View File

@@ -1734,7 +1734,9 @@ def LLVM_masked_compressstore
let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type);
}
//
/// Create a call to vscale intrinsic.
def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>;
// Atomic operations.
//
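
This intrinsic gives `vector.vscale` a one-to-one lowering target. A
sketch of the lowered op in generic form, assuming an i64 result (the
intrinsic is overloaded on its result type):

```mlir
%0 = "llvm.intr.vscale"() : () -> i64
```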

View File

@@ -483,10 +483,22 @@ Type getVectorElementType(Type type);
/// Returns the element count of any LLVM-compatible vector type.
llvm::ElementCount getVectorNumElements(Type type);
/// Returns whether a vector type is scalable or not.
bool isScalableVectorType(Type vectorType);
/// Creates an LLVM dialect-compatible vector type with the given element type
/// and length.
Type getVectorType(Type elementType, unsigned numElements,
bool isScalable = false);
/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getFixedVectorType(Type elementType, unsigned numElements);
/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getScalableVectorType(Type elementType, unsigned numElements);
/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive

View File

@@ -2383,4 +2383,36 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
}
//===----------------------------------------------------------------------===//
// VectorScaleOp
//===----------------------------------------------------------------------===//
// TODO: In the future, we might want to have scalable vectors with different
// scales for different dimensions. E.g.: vector<[16]x[16]xf32>, in
// which case we might need to add an index to 'vscale' to select one
// of them. In order to support GPUs, we might also want to differentiate
// between a 'global' scale, a scale that's fixed throughout the
// execution, and a 'local' scale that is fixed but might vary with each
// call to the function. For that, it might be useful to have a
// 'vector.scale.global' and a 'vector.scale.local' operation.
def VectorScaleOp : Vector_Op<"vscale",
[NoSideEffect]> {
let summary = "Load vector scale size";
let description = [{
The `vscale` op returns the scale of the scalable vectors, a positive
integer value that is constant at runtime but unknown at compile-time.
The scale of the vector indicates the multiplicity of the vectors and
vector operations. For example, a `vector<[4]xi32>` is equivalent to
`vscale` consecutive `vector<4xi32>`; and an operation on a
`vector<[4]xi32>` is equivalent to performing that operation `vscale`
times, once on each `<4xi32>` segment of the scalable vector. The `vscale`
op can be used to calculate the step in vector-length agnostic (VLA) loops.
Right now we only support one contiguous set of scalable dimensions, all of
them grouped and scaled with the value returned by 'vscale'.
}];
let results = (outs Index:$res);
let assemblyFormat = "attr-dict";
let verifier = ?;
}
#endif // VECTOR_OPS
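
As the description notes, `vscale` supplies the step for vector-length
agnostic loops. A sketch of such a loop processing `vector<[4]xf32>`
chunks (`%ub` and the loop body are illustrative):

```mlir
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%vs = vector.vscale
%step = arith.muli %c4, %vs : index
scf.for %i = %c0 to %ub step %step {
  // ... process one vector<[4]xf32> chunk starting at index %i ...
}
```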

View File

@@ -315,13 +315,18 @@ class VectorType::Builder {
public:
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()) {}
: shape(other.getShape()), elementType(other.getElementType()),
numScalableDims(other.getNumScalableDims()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType)
: shape(shape), elementType(elementType) {}
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0)
: shape(shape), elementType(elementType),
numScalableDims(numScalableDims) {}
Builder &setShape(ArrayRef<int64_t> newShape) {
Builder &setShape(ArrayRef<int64_t> newShape,
unsigned newNumScalableDims = 0) {
numScalableDims = newNumScalableDims;
shape = newShape;
return *this;
}
@@ -334,6 +339,8 @@ public:
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (pos >= shape.size() - numScalableDims)
numScalableDims--;
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.erase(storage.begin() + pos);
@@ -347,7 +354,7 @@ public:
operator Type() {
if (shape.empty())
return elementType;
return VectorType::get(shape, elementType);
return VectorType::get(shape, elementType, numScalableDims);
}
private:
@@ -355,6 +362,7 @@ private:
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
unsigned numScalableDims;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of

View File

@@ -892,16 +892,21 @@ def Builtin_Vector : Builtin_Type<"Vector", [
Syntax:
```
vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
vector-element-type ::= float-type | integer-type | index-type
static-dimension-list ::= (decimal-literal `x`)*
vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
static-dim-list ::= decimal-literal (`x` decimal-literal)*
```
The vector type represents a SIMD style vector, used by target-specific
operation sets like AVX. While the most common use is for 1D vectors (e.g.
vector<16 x f32>) we also support multidimensional registers on targets that
support them (like TPUs).
The vector type represents a SIMD style vector used by target-specific
operation sets like AVX or SVE. While the most common use is for 1D
vectors (e.g. vector<16 x f32>) we also support multidimensional registers
on targets that support them (like TPUs). The dimensions of a vector type
can be fixed-length, scalable, or a combination of the two. The scalable
dimensions in a vector are indicated between square brackets ([ ]), and
all fixed-length dimensions, if present, must precede the set of scalable
dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>`
is not.
Vector shapes must be positive decimal integers. 0D vectors are allowed by
omitting the dimension: `vector<f32>`.
@@ -913,19 +918,31 @@ def Builtin_Vector : Builtin_Type<"Vector", [
Examples:
```mlir
// A 2D fixed-length vector of 3x42 i32 elements.
vector<3x42xi32>
// A 1D scalable-length vector that contains a multiple of 4 f32 elements.
vector<[4]xf32>
// A 2D scalable-length vector that contains a multiple of 2x8 f32 elements.
vector<[2x8]xf32>
// A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements.
vector<4x[4]xf32>
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType
"Type":$elementType,
"unsigned":$numScalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType
"ArrayRef<int64_t>":$shape, "Type":$elementType,
CArg<"unsigned", "0">:$numScalableDims
), [{
return $_get(elementType.getContext(), shape, elementType);
return $_get(elementType.getContext(), shape, elementType,
numScalableDims);
}]>
];
let extraClassDeclaration = [{
@@ -933,13 +950,18 @@ def Builtin_Vector : Builtin_Type<"Vector", [
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
/// Returns true of the given type can be used as an element of a vector
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
return t.isa<IntegerType, IndexType, FloatType>();
}
/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
return getNumScalableDims() > 0;
}
/// Get or create a new VectorType with the same shape as `this` and an
/// element type of bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.

View File

@@ -216,6 +216,14 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
// Whether a type is a fixed-length VectorType.
def IsFixedVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() &&
!$_self.cast<VectorType>().isScalable()}]>;
// Whether a type is a scalable VectorType.
def IsScalableVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() &&
$_self.cast<VectorType>().isScalable()}]>;
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
@@ -611,6 +619,14 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;
class FixedVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
"fixed-length vector", "::mlir::VectorType">;
class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -643,6 +659,24 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;
// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
And<[IsFixedVectorTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()
== }]
# allowedlength>)>]>;
// Whether the number of elements of a scalable vector is from the given
// `allowedLengths` list
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()
== }]
# allowedlength>)>]>;
// Any vector where the number of elements is from the given
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
@@ -650,6 +684,20 @@ class VectorOfLength<list<int> allowedLengths> : Type<
" of length " # !interleave(allowedLengths, "/"),
"::mlir::VectorType">;
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list
class FixedVectorOfLength<list<int> allowedLengths> : Type<
IsFixedVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/"),
"::mlir::VectorType">;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
IsScalableVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/"),
"::mlir::VectorType">;
// Any vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes`
// list
@@ -660,10 +708,34 @@ class VectorOfLengthAndType<list<int> allowedLengths,
VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[FixedVectorOf<allowedTypes>.predicate,
FixedVectorOfLength<allowedLengths>.predicate]>,
FixedVectorOf<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[ScalableVectorOf<allowedTypes>.predicate,
ScalableVectorOfLength<allowedLengths>.predicate]>,
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVector : FixedVectorOf<[AnyType]>;
def AnyScalableVector : ScalableVectorOf<[AnyType]>;
// Shaped types.
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",

View File

@@ -411,7 +411,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
return {};
if (type.getShape().empty())
return VectorType::get({1}, elementType);
Type vectorType = VectorType::get(type.getShape().back(), elementType);
Type vectorType = VectorType::get(type.getShape().back(), elementType,
type.getNumScalableDims());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
auto shape = type.getShape();

View File

@@ -26,13 +26,21 @@ using namespace mlir::vector;
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
unsigned numScalableDims = tp.getNumScalableDims();
if (tp.getShape().size() == numScalableDims)
--numScalableDims;
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
numScalableDims);
}
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return VectorType::get(tp.getShape().take_back(), tp.getElementType());
unsigned numScalableDims = tp.getNumScalableDims();
if (numScalableDims > 0)
--numScalableDims;
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
numScalableDims);
}
// Helper that picks the proper sequence for inserting.
@@ -112,6 +120,10 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
namespace {
/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
: public ConvertOpToLLVMPattern<vector::BitCastOp> {
@@ -1064,7 +1076,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,

View File

@@ -999,7 +999,8 @@ static Type getI1SameShape(Type type) {
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type);
return VectorType::get(vectorType.getShape(), i1Type,
vectorType.getNumScalableDims());
return i1Type;
}

View File

@@ -25,12 +25,6 @@ using namespace arm_sve;
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc"
static Type getI1SameShape(Type type);
static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
arith::CmpIPredicate predicate, Value lhs,
Value rhs);
static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
arith::CmpFPredicate predicate, Value lhs,
Value rhs);
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -43,31 +37,6 @@ void ArmSVEDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// ScalableVectorType
//===----------------------------------------------------------------------===//
void ScalableVectorType::print(AsmPrinter &printer) const {
printer << "<";
for (int64_t dim : getShape())
printer << dim << 'x';
printer << getElementType() << '>';
}
Type ScalableVectorType::parse(AsmParser &parser) {
SmallVector<int64_t> dims;
Type eltType;
if (parser.parseLess() ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
parser.parseType(eltType) || parser.parseGreater())
return {};
return ScalableVectorType::get(eltType.getContext(), dims, eltType);
}
//===----------------------------------------------------------------------===//
@@ -77,30 +46,8 @@ Type ScalableVectorType::parse(AsmParser &parser) {
// Return the scalable vector of the same shape and containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto sVectorType = type.dyn_cast<ScalableVectorType>())
return ScalableVectorType::get(type.getContext(), sVectorType.getShape(),
i1Type);
if (auto sVectorType = type.dyn_cast<VectorType>())
return VectorType::get(sVectorType.getShape(), i1Type,
sVectorType.getNumScalableDims());
return nullptr;
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
arith::CmpFPredicate predicate, Value lhs,
Value rhs) {
result.addOperands({lhs, rhs});
result.types.push_back(getI1SameShape(lhs.getType()));
result.addAttribute(ScalableCmpFOp::getPredicateAttrName(),
build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
arith::CmpIPredicate predicate, Value lhs,
Value rhs) {
result.addOperands({lhs, rhs});
result.types.push_back(getI1SameShape(lhs.getType()));
result.addAttribute(ScalableCmpIOp::getPredicateAttrName(),
build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
}

View File

@@ -18,29 +18,6 @@
using namespace mlir;
using namespace mlir::arm_sve;
// Extract an LLVM IR type from the LLVM IR dialect type.
static Type unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
if (!LLVM::isCompatibleType(type))
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return type;
}
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
LLVMTypeConverter &converter) {
auto elementType = unwrap(converter.convertType(svType.getElementType()));
if (!elementType)
return {};
auto sVectorType =
LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
return sVectorType;
}
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
@@ -70,22 +47,10 @@ public:
}
};
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
ScalableVectorType svType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1 ||
!inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
return Value();
return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
.getResult(0);
}
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
@@ -114,136 +79,10 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = loadOp.getMemRefType();
if (!isConvertibleAndHasIdentityMaps(type))
return failure();
LLVMTypeConverter converter(loadOp.getContext());
auto resultType = loadOp.result().getType();
LLVM::LLVMPointerType llvmDataTypePtr;
if (resultType.isa<VectorType>()) {
llvmDataTypePtr =
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
} else if (resultType.isa<ScalableVectorType>()) {
llvmDataTypePtr = LLVM::LLVMPointerType::get(
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
converter)
.getValue());
}
Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
adaptor.index(), rewriter);
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
loadOp.getLoc(), llvmDataTypePtr, dataPtr);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
return success();
}
};
// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
struct ScalableStoreOpLowering
: public ConvertOpToLLVMPattern<ScalableStoreOp> {
using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = storeOp.getMemRefType();
if (!isConvertibleAndHasIdentityMaps(type))
return failure();
LLVMTypeConverter converter(storeOp.getContext());
auto resultType = storeOp.value().getType();
LLVM::LLVMPointerType llvmDataTypePtr;
if (resultType.isa<VectorType>()) {
llvmDataTypePtr =
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
} else if (resultType.isa<ScalableVectorType>()) {
llvmDataTypePtr = LLVM::LLVMPointerType::get(
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
converter)
.getValue());
}
Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
adaptor.index(), rewriter);
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
storeOp.getLoc(), llvmDataTypePtr, dataPtr);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
bitCastedPtr);
return success();
}
};
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
// clang-format off
patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
>(converter);
// clang-format on
}
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
// clang-format off
target.addIllegalOp<ScalableAddIOp,
ScalableAddFOp,
ScalableSubIOp,
ScalableSubFOp,
ScalableMulIOp,
ScalableMulFOp,
ScalableSDivIOp,
ScalableUDivIOp,
ScalableDivFOp>();
// clang-format on
}
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
// clang-format off
patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
>(converter);
// clang-format on
}
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
// clang-format off
target.addIllegalOp<ScalableCmpFOp,
ScalableCmpIOp>();
// clang-format on
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// Populate conversion patterns
// Remove any ArmSVE-specific types from function signatures and results.
populateFuncOpTypeConversionPattern(patterns, converter);
converter.addConversion([&converter](ScalableVectorType svType) {
return convertScalableVectorTypeToLLVM(svType, converter);
});
converter.addSourceMaterialization(addUnrealizedCast);
// clang-format off
patterns.add<ForwardOperands<CallOp>,
@@ -254,7 +93,6 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
VectorScaleOpLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
@@ -264,11 +102,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering>(converter);
patterns.add<ScalableLoadOpLowering,
ScalableStoreOpLowering>(converter);
// clang-format on
populateBasicSVEArithmeticExportPatterns(converter, patterns);
populateSVEMaskGenerationExportPatterns(converter, patterns);
}
void mlir::configureArmSVELegalizeForExportTarget(
@@ -278,7 +112,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
VectorScaleIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
@@ -292,7 +125,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
VectorScaleOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
@@ -301,25 +133,6 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp,
ScalableLoadOp,
ScalableStoreOp>();
ScalableMaskedDivFOp>();
// clang-format on
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
return true;
return false;
};
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
});
target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
[hasScalableVectorType](Operation *op) {
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
configureBasicSVEArithmeticLegalizations(target);
configureSVEMaskGenerationLegalizations(target);
}
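
The masked ops kept by this patch now operate directly on builtin
scalable vectors. A sketch of one, assuming the assembly format carries
over from the pre-patch dialect:

```mlir
%0 = arm_sve.masked.addi %mask, %a, %b : vector<[4]xi1>, vector<[4]xi32>
```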

View File

@@ -155,12 +155,14 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(trailingTypeLoc,
"expected LLVM dialect-compatible type");
if (LLVM::isCompatibleVectorType(type)) {
if (type.isa<LLVM::LLVMScalableVectorType>()) {
resultType = LLVM::LLVMScalableVectorType::get(
resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
if (LLVM::isScalableVectorType(type)) {
resultType = LLVM::getVectorType(
resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
/*isScalable=*/true);
} else {
resultType = LLVM::getFixedVectorType(
resultType, LLVM::getVectorNumElements(type).getFixedValue());
resultType = LLVM::getVectorType(
resultType, LLVM::getVectorNumElements(type).getFixedValue(),
/*isScalable=*/false);
}
}

View File

@@ -775,7 +775,12 @@ Type mlir::LLVM::getVectorElementType(Type type) {
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
.Case<LLVMFixedVectorType, VectorType>([](auto ty) {
.Case([](VectorType ty) {
if (ty.isScalable())
return llvm::ElementCount::getScalable(ty.getNumElements());
return llvm::ElementCount::getFixed(ty.getNumElements());
})
.Case([](LLVMFixedVectorType ty) {
return llvm::ElementCount::getFixed(ty.getNumElements());
})
.Case([](LLVMScalableVectorType ty) {
@@ -786,6 +791,31 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
});
}
bool mlir::LLVM::isScalableVectorType(Type vectorType) {
assert(
(vectorType
.isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>()) &&
"expected LLVM-compatible vector type");
return !vectorType.isa<LLVMFixedVectorType>() &&
(vectorType.isa<LLVMScalableVectorType>() ||
vectorType.cast<VectorType>().isScalable());
}
Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
bool isScalable) {
bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
bool useBuiltIn = VectorType::isValidElementType(elementType);
(void)useBuiltIn;
assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
"to be either builtin or LLVM dialect type");
if (useLLVM) {
if (isScalable)
return LLVMScalableVectorType::get(elementType, numElements);
return LLVMFixedVectorType::get(elementType, numElements);
}
return VectorType::get(numElements, elementType, (unsigned)isScalable);
}
Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
bool useBuiltIn = VectorType::isValidElementType(elementType);
@@ -797,6 +827,18 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
return VectorType::get(numElements, elementType);
}
Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
bool useBuiltIn = VectorType::isValidElementType(elementType);
(void)useBuiltIn;
assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
"type to be either builtin or LLVM dialect "
"type");
if (useLLVM)
return LLVMScalableVectorType::get(elementType, numElements);
return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
}
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");

View File

@@ -515,7 +515,8 @@ static Type getI1SameShape(Type type) {
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type);
return VectorType::get(vectorType.getShape(), i1Type,
vectorType.getNumScalableDims());
return i1Type;
}

View File

@@ -1954,8 +1954,19 @@ void AsmPrinter::Impl::printType(Type type) {
})
.Case<VectorType>([&](VectorType vectorTy) {
os << "vector<";
for (int64_t dim : vectorTy.getShape())
os << dim << 'x';
auto vShape = vectorTy.getShape();
unsigned lastDim = vShape.size();
unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
unsigned dimIdx = 0;
for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
os << vShape[dimIdx] << 'x';
if (vectorTy.isScalable()) {
os << '[';
unsigned secondToLastDim = lastDim - 1;
for (; dimIdx < secondToLastDim; dimIdx++)
os << vShape[dimIdx] << 'x';
os << vShape[dimIdx] << "]x";
}
printType(vectorTy.getElementType());
os << '>';
})

View File

@@ -1165,8 +1165,9 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
else if (inType.isa<UnrankedTensorType>())
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
else if (inType.isa<VectorType>())
newArrayType = VectorType::get(inType.getShape(), newElementType);
else if (auto vType = inType.dyn_cast<VectorType>())
newArrayType = VectorType::get(vType.getShape(), newElementType,
vType.getNumScalableDims());
else
assert(newArrayType && "Unhandled tensor type");

View File

@@ -294,8 +294,8 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
if (isa<TensorType>())
return RankedTensorType::get(shape, elementType);
if (isa<VectorType>())
return VectorType::get(shape, elementType);
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(shape, elementType, vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone case");
}
@@ -317,8 +317,8 @@ ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
if (isa<TensorType>())
return RankedTensorType::get(shape, getElementType());
if (isa<VectorType>())
return VectorType::get(shape, getElementType());
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone case");
}
@@ -340,8 +340,8 @@ ShapedType ShapedType::clone(Type elementType) {
return UnrankedTensorType::get(elementType);
}
if (isa<VectorType>())
return VectorType::get(getShape(), elementType);
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone hit");
}
@@ -441,7 +441,8 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
//===----------------------------------------------------------------------===//
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -460,10 +461,10 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
if (auto et = getElementType().dyn_cast<IntegerType>())
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt);
return VectorType::get(getShape(), scaledEt, getNumScalableDims());
if (auto et = getElementType().dyn_cast<FloatType>())
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt);
return VectorType::get(getShape(), scaledEt, getNumScalableDims());
return VectorType();
}

View File

@@ -196,8 +196,11 @@ public:
/// Parse a vector type.
VectorType parseVectorType();
ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
unsigned &numScalableDims);
ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true);
ParseResult parseIntegerInDimensionList(int64_t &value);
ParseResult parseXInDimensionList();
/// Parse strided layout specification.

View File

@@ -13,6 +13,7 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"
using namespace mlir;
@@ -442,8 +443,9 @@ Type Parser::parseTupleType() {
/// Parse a vector type.
///
/// vector-type ::= `vector` `<` static-dimension-list type `>`
/// static-dimension-list ::= (decimal-literal `x`)*
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
///
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
@@ -452,7 +454,8 @@ VectorType Parser::parseVectorType() {
return nullptr;
SmallVector<int64_t, 4> dimensions;
if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
unsigned numScalableDims;
if (parseVectorDimensionList(dimensions, numScalableDims))
return nullptr;
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
return emitError(getToken().getLoc(),
@@ -464,11 +467,59 @@ VectorType Parser::parseVectorType() {
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
return VectorType::get(dimensions, elementType);
return VectorType::get(dimensions, elementType, numScalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list,
/// and returns the number of scalable dimensions in `numScalableDims`.
///
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
///
ParseResult
Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
unsigned &numScalableDims) {
numScalableDims = 0;
// If there is a set of fixed-length dimensions, consume it
while (getToken().is(Token::integer)) {
int64_t value;
if (parseIntegerInDimensionList(value))
return failure();
dimensions.push_back(value);
// Make sure we have an 'x' or something like 'xbf32'.
if (parseXInDimensionList())
return failure();
}
// If there is a set of scalable dimensions, consume it
if (consumeIf(Token::l_square)) {
while (getToken().is(Token::integer)) {
int64_t value;
if (parseIntegerInDimensionList(value))
return failure();
dimensions.push_back(value);
numScalableDims++;
// Check if we have reached the end of the scalable dimension list
if (consumeIf(Token::r_square)) {
// Make sure we have something like 'xbf32'.
if (parseXInDimensionList())
return failure();
return success();
}
// Make sure we have an 'x'
if (parseXInDimensionList())
return failure();
}
// If we make it here, we've finished parsing the dimension list
// without finding ']' closing the set of scalable dimensions
return emitError("missing ']' closing set of scalable dimensions");
}
return success();
}
/// Parse a dimension list of a tensor or memref type. This populates the
@@ -490,28 +541,11 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
return emitError("expected static shape");
dimensions.push_back(-1);
} else {
// Hexadecimal integer literals (starting with `0x`) are not allowed in
// aggregate type declarations. Therefore, `0xf32` should be processed as
// a sequence of separate elements `0`, `x`, `f32`.
if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
// We can get here only if the token is an integer literal. Hexadecimal
// integer literals can only start with `0x` (`1x` wouldn't lex as a
// literal, just `1` would, at which point we don't get into this
// branch).
assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
dimensions.push_back(0);
state.lex.resetPointer(getTokenSpelling().data() + 1);
consumeToken();
} else {
// Make sure this integer value is in bound and valid.
Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
return emitError("invalid dimension");
dimensions.push_back((int64_t)dimension.getValue());
consumeToken(Token::integer);
}
int64_t value;
if (parseIntegerInDimensionList(value))
return failure();
dimensions.push_back(value);
}
// Make sure we have an 'x' or something like 'xbf32'.
if (parseXInDimensionList())
return failure();
@@ -520,6 +554,30 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
return success();
}
ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
// Hexadecimal integer literals (starting with `0x`) are not allowed in
// aggregate type declarations. Therefore, `0xf32` should be processed as
// a sequence of separate elements `0`, `x`, `f32`.
if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
// We can get here only if the token is an integer literal. Hexadecimal
// integer literals can only start with `0x` (`1x` wouldn't lex as a
// literal, just `1` would, at which point we don't get into this
// branch).
assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
value = 0;
state.lex.resetPointer(getTokenSpelling().data() + 1);
consumeToken();
} else {
// Make sure this integer value is in bounds and valid.
Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
return emitError("invalid dimension");
value = (int64_t)dimension.getValue();
consumeToken(Token::integer);
}
return success();
}
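A hedged illustration of what the lexer reset above buys when parsing vector shapes, restating the comment rather than adding behavior:
// vector<0xf32>   lexes as `0` `x` `f32`         -> shape = {0},    element type f32
// vector<0x2xf32> lexes as `0` `x` `2` `x` `f32` -> shape = {0, 2}, element type f32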
/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
/// token.

View File

@@ -242,10 +242,14 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
elementType = arrayTy->getElementType();
numElements = arrayTy->getNumElements();
} else if (auto fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
elementType = fVectorTy->getElementType();
numElements = fVectorTy->getNumElements();
} else if (auto sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
elementType = sVectorTy->getElementType();
numElements = sVectorTy->getMinNumElements();
} else {
auto *vectorTy = cast<llvm::FixedVectorType>(llvmType);
elementType = vectorTy->getElementType();
numElements = vectorTy->getNumElements();
llvm_unreachable("unrecognized constant vector type");
}
// Splat value is a scalar. Extract it only if the element type is not
// another sequence type. The recursion terminates because each step removes
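To make the fixed/scalable split in getLLVMConstant concrete, here is a minimal hedged sketch (the helper name is hypothetical): fixed vectors report an exact element count, while scalable vectors report only the minimum count, which the hardware multiplies by vscale at run time.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/Support/ErrorHandling.h"
// Hypothetical helper, for illustration only.
unsigned splatElementCount(llvm::Type *llvmType) {
  if (auto *fixedTy = llvm::dyn_cast<llvm::FixedVectorType>(llvmType))
    return fixedTy->getNumElements();        // e.g. <8 x i32>          -> 8
  if (auto *scalableTy = llvm::dyn_cast<llvm::ScalableVectorType>(llvmType))
    return scalableTy->getMinNumElements();  // e.g. <vscale x 4 x i32> -> 4
  llvm_unreachable("not a vector type");
}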

View File

@@ -137,6 +137,9 @@ private:
llvm::Type *translate(VectorType type) {
assert(LLVM::isCompatibleVectorType(type) &&
"expected compatible with LLVM vector type");
if (type.isScalable())
return llvm::ScalableVectorType::get(translateType(type.getElementType()),
type.getNumElements());
return llvm::FixedVectorType::get(translateType(type.getElementType()),
type.getNumElements());
}
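A hedged sketch of the two LLVM types this branch produces for a 32-bit float element (function and variable names hypothetical):
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
void translateExamples(llvm::LLVMContext &ctx) {
  llvm::Type *f32 = llvm::Type::getFloatTy(ctx);
  // vector<4xf32>   ->  <4 x float>
  llvm::VectorType *fixedTy = llvm::FixedVectorType::get(f32, 4);
  // vector<[4]xf32> ->  <vscale x 4 x float>
  llvm::VectorType *scalableTy = llvm::ScalableVectorType::get(f32, 4);
  (void)fixedTy;
  (void)scalableTy;
}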

View File

@@ -19,6 +19,12 @@ func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_addi_scalable_vector
func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.addi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_subi
func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.subi %arg0, %arg1 : i64
@@ -37,6 +43,12 @@ func @test_subi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_subi_scalable_vector
func @test_subi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.subi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_muli
func @test_muli(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.muli %arg0, %arg1 : i64
@@ -55,6 +67,12 @@ func @test_muli_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_muli_scalable_vector
func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.muli %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_divui
func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.divui %arg0, %arg1 : i64
@@ -73,6 +91,12 @@ func @test_divui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_divui_scalable_vector
func @test_divui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.divui %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_divsi
func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.divsi %arg0, %arg1 : i64
@@ -91,6 +115,12 @@ func @test_divsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_divsi_scalable_vector
func @test_divsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.divsi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_remui
func @test_remui(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.remui %arg0, %arg1 : i64
@@ -109,6 +139,12 @@ func @test_remui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_remui_scalable_vector
func @test_remui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.remui %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_remsi
func @test_remsi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.remsi %arg0, %arg1 : i64
@@ -127,6 +163,12 @@ func @test_remsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_remsi_scalable_vector
func @test_remsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.remsi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_andi
func @test_andi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.andi %arg0, %arg1 : i64
@@ -145,6 +187,12 @@ func @test_andi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_andi_scalable_vector
func @test_andi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.andi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_ori
func @test_ori(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.ori %arg0, %arg1 : i64
@@ -163,6 +211,12 @@ func @test_ori_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8x
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_ori_scalable_vector
func @test_ori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.ori %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_xori
func @test_xori(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.xori %arg0, %arg1 : i64
@@ -181,6 +235,12 @@ func @test_xori_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_xori_scalable_vector
func @test_xori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.xori %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_ceildivsi
func @test_ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.ceildivsi %arg0, %arg1 : i64
@@ -199,6 +259,12 @@ func @test_ceildivsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vec
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_ceildivsi_scalable_vector
func @test_ceildivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.ceildivsi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_floordivsi
func @test_floordivsi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.floordivsi %arg0, %arg1 : i64
@@ -217,6 +283,12 @@ func @test_floordivsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> ve
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_floordivsi_scalable_vector
func @test_floordivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.floordivsi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_shli
func @test_shli(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.shli %arg0, %arg1 : i64
@@ -235,6 +307,12 @@ func @test_shli_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_shli_scalable_vector
func @test_shli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.shli %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_shrui
func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.shrui %arg0, %arg1 : i64
@@ -253,6 +331,12 @@ func @test_shrui_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_shrui_scalable_vector
func @test_shrui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.shrui %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_shrsi
func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 {
%0 = arith.shrsi %arg0, %arg1 : i64
@@ -271,6 +355,12 @@ func @test_shrsi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_shrsi_scalable_vector
func @test_shrsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
%0 = arith.shrsi %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_negf
func @test_negf(%arg0 : f64) -> f64 {
%0 = arith.negf %arg0 : f64
@@ -289,6 +379,12 @@ func @test_negf_vector(%arg0 : vector<8xf64>) -> vector<8xf64> {
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_negf_scalable_vector
func @test_negf_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.negf %arg0 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_addf
func @test_addf(%arg0 : f64, %arg1 : f64) -> f64 {
%0 = arith.addf %arg0, %arg1 : f64
@@ -307,6 +403,12 @@ func @test_addf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_addf_scalable_vector
func @test_addf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.addf %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_subf
func @test_subf(%arg0 : f64, %arg1 : f64) -> f64 {
%0 = arith.subf %arg0, %arg1 : f64
@@ -325,6 +427,12 @@ func @test_subf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_subf_scalable_vector
func @test_subf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.subf %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_mulf
func @test_mulf(%arg0 : f64, %arg1 : f64) -> f64 {
%0 = arith.mulf %arg0, %arg1 : f64
@@ -343,6 +451,12 @@ func @test_mulf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_mulf_scalable_vector
func @test_mulf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.mulf %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_divf
func @test_divf(%arg0 : f64, %arg1 : f64) -> f64 {
%0 = arith.divf %arg0, %arg1 : f64
@@ -361,6 +475,12 @@ func @test_divf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_divf_scalable_vector
func @test_divf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.divf %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_remf
func @test_remf(%arg0 : f64, %arg1 : f64) -> f64 {
%0 = arith.remf %arg0, %arg1 : f64
@@ -379,6 +499,12 @@ func @test_remf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_remf_scalable_vector
func @test_remf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> {
%0 = arith.remf %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_extui
func @test_extui(%arg0 : i32) -> i64 {
%0 = arith.extui %arg0 : i32 to i64
@@ -397,6 +523,12 @@ func @test_extui_vector(%arg0 : vector<8xi32>) -> vector<8xi64> {
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_extui_scalable_vector
func @test_extui_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> {
%0 = arith.extui %arg0 : vector<[8]xi32> to vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_extsi
func @test_extsi(%arg0 : i32) -> i64 {
%0 = arith.extsi %arg0 : i32 to i64
@@ -415,6 +547,12 @@ func @test_extsi_vector(%arg0 : vector<8xi32>) -> vector<8xi64> {
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_extsi_scalable_vector
func @test_extsi_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> {
%0 = arith.extsi %arg0 : vector<[8]xi32> to vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_extf
func @test_extf(%arg0 : f32) -> f64 {
%0 = arith.extf %arg0 : f32 to f64
@@ -433,6 +571,12 @@ func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_extf_scalable_vector
func @test_extf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xf64> {
%0 = arith.extf %arg0 : vector<[8]xf32> to vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_trunci
func @test_trunci(%arg0 : i32) -> i16 {
%0 = arith.trunci %arg0 : i32 to i16
@@ -451,6 +595,12 @@ func @test_trunci_vector(%arg0 : vector<8xi32>) -> vector<8xi16> {
return %0 : vector<8xi16>
}
// CHECK-LABEL: test_trunci_scalable_vector
func @test_trunci_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi16> {
%0 = arith.trunci %arg0 : vector<[8]xi32> to vector<[8]xi16>
return %0 : vector<[8]xi16>
}
// CHECK-LABEL: test_truncf
func @test_truncf(%arg0 : f32) -> bf16 {
%0 = arith.truncf %arg0 : f32 to bf16
@@ -469,6 +619,12 @@ func @test_truncf_vector(%arg0 : vector<8xf32>) -> vector<8xbf16> {
return %0 : vector<8xbf16>
}
// CHECK-LABEL: test_truncf_scalable_vector
func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf16> {
%0 = arith.truncf %arg0 : vector<[8]xf32> to vector<[8]xbf16>
return %0 : vector<[8]xbf16>
}
// CHECK-LABEL: test_uitofp
func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32
@@ -487,6 +643,12 @@ func @test_uitofp_vector(%arg0 : vector<8xi32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
// CHECK-LABEL: test_uitofp_scalable_vector
func @test_uitofp_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xf32> {
%0 = arith.uitofp %arg0 : vector<[8]xi32> to vector<[8]xf32>
return %0 : vector<[8]xf32>
}
// CHECK-LABEL: test_sitofp
func @test_sitofp(%arg0 : i16) -> f64 {
%0 = arith.sitofp %arg0 : i16 to f64
@@ -505,6 +667,12 @@ func @test_sitofp_vector(%arg0 : vector<8xi16>) -> vector<8xf64> {
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_sitofp_scalable_vector
func @test_sitofp_scalable_vector(%arg0 : vector<[8]xi16>) -> vector<[8]xf64> {
%0 = arith.sitofp %arg0 : vector<[8]xi16> to vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_fptoui
func @test_fptoui(%arg0 : bf16) -> i8 {
%0 = arith.fptoui %arg0 : bf16 to i8
@@ -523,6 +691,12 @@ func @test_fptoui_vector(%arg0 : vector<8xbf16>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
// CHECK-LABEL: test_fptoui_scalable_vector
func @test_fptoui_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xi8> {
%0 = arith.fptoui %arg0 : vector<[8]xbf16> to vector<[8]xi8>
return %0 : vector<[8]xi8>
}
// CHECK-LABEL: test_fptosi
func @test_fptosi(%arg0 : f64) -> i64 {
%0 = arith.fptosi %arg0 : f64 to i64
@@ -541,6 +715,12 @@ func @test_fptosi_vector(%arg0 : vector<8xf64>) -> vector<8xi64> {
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_fptosi_scalable_vector
func @test_fptosi_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xi64> {
%0 = arith.fptosi %arg0 : vector<[8]xf64> to vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_index_cast0
func @test_index_cast0(%arg0 : i32) -> index {
%0 = arith.index_cast %arg0 : i32 to index
@@ -559,6 +739,12 @@ func @test_index_cast_vector0(%arg0 : vector<8xi32>) -> vector<8xindex> {
return %0 : vector<8xindex>
}
// CHECK-LABEL: test_index_cast_scalable_vector0
func @test_index_cast_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> {
%0 = arith.index_cast %arg0 : vector<[8]xi32> to vector<[8]xindex>
return %0 : vector<[8]xindex>
}
// CHECK-LABEL: test_index_cast1
func @test_index_cast1(%arg0 : index) -> i64 {
%0 = arith.index_cast %arg0 : index to i64
@@ -577,6 +763,12 @@ func @test_index_cast_vector1(%arg0 : vector<8xindex>) -> vector<8xi64> {
return %0 : vector<8xi64>
}
// CHECK-LABEL: test_index_cast_scalable_vector1
func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> {
%0 = arith.index_cast %arg0 : vector<[8]xindex> to vector<[8]xi64>
return %0 : vector<[8]xi64>
}
// CHECK-LABEL: test_bitcast0
func @test_bitcast0(%arg0 : i64) -> f64 {
%0 = arith.bitcast %arg0 : i64 to f64
@@ -595,6 +787,12 @@ func @test_bitcast_vector0(%arg0 : vector<8xi64>) -> vector<8xf64> {
return %0 : vector<8xf64>
}
// CHECK-LABEL: test_bitcast_scalable_vector0
func @test_bitcast_scalable_vector0(%arg0 : vector<[8]xi64>) -> vector<[8]xf64> {
%0 = arith.bitcast %arg0 : vector<[8]xi64> to vector<[8]xf64>
return %0 : vector<[8]xf64>
}
// CHECK-LABEL: test_bitcast1
func @test_bitcast1(%arg0 : f32) -> i32 {
%0 = arith.bitcast %arg0 : f32 to i32
@@ -613,6 +811,12 @@ func @test_bitcast_vector1(%arg0 : vector<8xf32>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
// CHECK-LABEL: test_bitcast_scalable_vector1
func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]xi32> {
%0 = arith.bitcast %arg0 : vector<[8]xf32> to vector<[8]xi32>
return %0 : vector<[8]xi32>
}
// CHECK-LABEL: test_cmpi
func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
%0 = arith.cmpi ne, %arg0, %arg1 : i64
@@ -631,6 +835,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi1>
}
// CHECK-LABEL: test_cmpi_scalable_vector
func @test_cmpi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi1> {
%0 = arith.cmpi ult, %arg0, %arg1 : vector<[8]xi64>
return %0 : vector<[8]xi1>
}
// CHECK-LABEL: test_cmpi_vector_0d
func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
@@ -655,6 +865,12 @@ func @test_cmpf_vector(%arg0 : vector<8xf64>, %arg1 : vector<8xf64>) -> vector<8
return %0 : vector<8xi1>
}
// CHECK-LABEL: test_cmpf_scalable_vector
func @test_cmpf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xi1> {
%0 = arith.cmpf ult, %arg0, %arg1 : vector<[8]xf64>
return %0 : vector<[8]xi1>
}
// CHECK-LABEL: test_index_cast
func @test_index_cast(%arg0 : index) -> i64 {
%0 = arith.index_cast %arg0 : index to i64
@@ -713,9 +929,11 @@ func @test_constant() -> () {
// CHECK-LABEL: func @maximum
func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
%max_vector = arith.maxf %v1, %v2 : vector<4xf32>
%max_scalable_vector = arith.maxf %sv1, %sv2 : vector<[4]xf32>
%max_float = arith.maxf %f1, %f2 : f32
%max_signed = arith.maxsi %i1, %i2 : i32
%max_unsigned = arith.maxui %i1, %i2 : i32
@@ -724,9 +942,11 @@ func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
// CHECK-LABEL: func @minimum
func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
%min_vector = arith.minf %v1, %v2 : vector<4xf32>
%min_scalable_vector = arith.minf %sv1, %sv2 : vector<[4]xf32>
%min_float = arith.minf %f1, %f2 : f32
%min_signed = arith.minsi %i1, %i2 : i32
%min_unsigned = arith.minui %i1, %i2 : i32

View File

@@ -1,168 +1,118 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm | mlir-opt | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.sdot
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
func @arm_sve_smmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.smmla
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
func @arm_sve_udot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.udot
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
func @arm_sve_ummla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: arm_sve.intr.ummla
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>,
%c: !arm_sve.vector<4xi32>,
%d: !arm_sve.vector<4xi32>,
%e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
%0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
// CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
%1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
// CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
%2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
// CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
%3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
// CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
%4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
return %4 : !arm_sve.vector<4xi32>
func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
%d: vector<[4]xi32>,
%e: vector<[4]xi32>,
%mask: vector<[4]xi1>
) -> vector<[4]xi32> {
// CHECK: arm_sve.intr.add{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%0 = arm_sve.masked.addi %mask, %a, %b : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.intr.sub{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%1 = arm_sve.masked.subi %mask, %0, %c : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.intr.mul{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%2 = arm_sve.masked.muli %mask, %1, %d : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.intr.sdiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.intr.udiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>,
vector<[4]xi32>
return %4 : vector<[4]xi32>
}
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>,
%c: !arm_sve.vector<4xf32>,
%d: !arm_sve.vector<4xf32>,
%e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
// CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
%0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
// CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
%1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
// CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
%2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
// CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
%3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
return %3 : !arm_sve.vector<4xf32>
func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
%b: vector<[4]xf32>,
%c: vector<[4]xf32>,
%d: vector<[4]xf32>,
%e: vector<[4]xf32>,
%mask: vector<[4]xi1>
) -> vector<[4]xf32> {
// CHECK: arm_sve.intr.fadd{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
%0 = arm_sve.masked.addf %mask, %a, %b : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.intr.fsub{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
%1 = arm_sve.masked.subf %mask, %0, %c : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.intr.fmul{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
%2 = arm_sve.masked.mulf %mask, %1, %d : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.intr.fdiv{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32>
%3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>,
vector<[4]xf32>
return %3 : vector<[4]xf32>
}
func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>,
%c: !arm_sve.vector<4xi32>,
%d: !arm_sve.vector<4xi32>,
%e: !arm_sve.vector<4xi32>,
%mask: !arm_sve.vector<4xi1>
) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
return %4 : !arm_sve.vector<4xi32>
}
func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>,
%c: !arm_sve.vector<4xf32>,
%d: !arm_sve.vector<4xf32>,
%e: !arm_sve.vector<4xf32>,
%mask: !arm_sve.vector<4xi1>
) -> !arm_sve.vector<4xf32> {
// CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
%0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
%1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
%2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>) -> !llvm.vec<? x 4 x f32>
%3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
return %3 : !arm_sve.vector<4xf32>
}
func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>)
-> !arm_sve.vector<4xi1> {
// CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec<? x 4 x f32>
%0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
return %0 : !arm_sve.vector<4xi1>
}
func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi1> {
// CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec<? x 4 x i32>
%0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi1>
}
func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
%z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32>
// CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec<? x 4 x i32>
%agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32>
// CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec<? x 4 x i32>
%bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32>
// CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
%3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
return %3 : !arm_sve.vector<4xi32>
func @arm_sve_abs_diff(%a: vector<[4]xi32>,
%b: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: llvm.mlir.constant(dense<0> : vector<[4]xi32>) : vector<[4]xi32>
%z = arith.subi %a, %a : vector<[4]xi32>
// CHECK: llvm.icmp "sge" {{.*}}: vector<[4]xi32>
%agb = arith.cmpi sge, %a, %b : vector<[4]xi32>
// CHECK: llvm.icmp "slt" {{.*}}: vector<[4]xi32>
%bga = arith.cmpi slt, %a, %b : vector<[4]xi32>
// CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%0 = arm_sve.masked.subi %agb, %a, %b : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%1 = arm_sve.masked.subi %bga, %b, %a : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%2 = arm_sve.masked.addi %agb, %z, %0 : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32>
%3 = arm_sve.masked.addi %bga, %2, %1 : vector<[4]xi1>,
vector<[4]xi32>
return %3 : vector<[4]xi32>
}
func @get_vector_scale() -> index {
// CHECK: arm_sve.vscale
%0 = arm_sve.vector_scale : index
// CHECK: llvm.intr.vscale
%0 = vector.vscale
return %0 : index
}
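Pieced together from the tests in this patch (a hedged summary, not a new test case), the scale query now flows through three representations:
// vector.vscale                     (Vector dialect; replaces arm_sve.vector_scale)
//   -> "llvm.intr.vscale"()         (LLVM dialect, after -convert-vector-to-llvm)
//   -> call i64 @llvm.vscale.i64()  (LLVM IR, after mlir-translate --mlir-to-llvmir)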

View File

@@ -1,28 +0,0 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%vs = arm_sve.vector_scale : index
%step = arith.muli %c4, %vs : index
// CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
scf.for %i0 = %c0 to %size step %step {
// CHECK: [[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
// CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
// CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
// CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
%0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32>
// CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
// CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
// CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32>
}
return
}

View File

@@ -1,137 +1,84 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32
func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.sdot {{.*}}: vector<[16]xi8> to vector<[4]xi32
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3
func @arm_sve_smmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.smmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32
func @arm_sve_udot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.udot {{.*}}: vector<[16]xi8> to vector<[4]xi32
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3
func @arm_sve_ummla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: arm_sve.ummla {{.*}}: vector<[16]xi8> to vector<[4]xi3
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
vector<[16]xi8> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
%0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
// CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
%1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
return %1 : !arm_sve.vector<4xi32>
}
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>,
%c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
// CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
%0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
// CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
%1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
return %1 : !arm_sve.vector<4xf32>
}
func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>,
%c: !arm_sve.vector<4xi32>,
%d: !arm_sve.vector<4xi32>,
%e: !arm_sve.vector<4xi32>,
%mask: !arm_sve.vector<4xi1>)
-> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
// CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
%d: vector<[4]xi32>,
%e: vector<[4]xi32>,
%mask: vector<[4]xi1>)
-> vector<[4]xi32> {
// CHECK: arm_sve.masked.muli {{.*}}: vector<[4]xi1>, vector<
%0 = arm_sve.masked.muli %mask, %a, %b : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.masked.addi {{.*}}: vector<[4]xi1>, vector<
%1 = arm_sve.masked.addi %mask, %0, %c : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.masked.subi {{.*}}: vector<[4]xi1>, vector<
%2 = arm_sve.masked.subi %mask, %1, %d : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.masked.divi_signed
%3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
%3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>,
vector<[4]xi32>
// CHECK: arm_sve.masked.divi_unsigned
%4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xi32>
return %2 : !arm_sve.vector<4xi32>
%4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>,
vector<[4]xi32>
return %2 : vector<[4]xi32>
}
func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>,
%c: !arm_sve.vector<4xf32>,
%d: !arm_sve.vector<4xf32>,
%e: !arm_sve.vector<4xf32>,
%mask: !arm_sve.vector<4xi1>)
-> !arm_sve.vector<4xf32> {
// CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.masked.addf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
// CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector
%3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
!arm_sve.vector<4xf32>
return %3 : !arm_sve.vector<4xf32>
}
func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
%b: !arm_sve.vector<4xf32>)
-> !arm_sve.vector<4xi1> {
// CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32>
%0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
return %0 : !arm_sve.vector<4xi1>
}
func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
%b: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi1> {
// CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32>
%0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi1>
}
func @arm_sve_memory(%v: !arm_sve.vector<4xi32>,
%m: memref<?xi32>)
-> !arm_sve.vector<4xi32> {
%c0 = arith.constant 0 : index
// CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref<?xi32>
%0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref<?xi32>
// CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref<?xi32>
arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref<?xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index
return %0 : index
func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
%b: vector<[4]xf32>,
%c: vector<[4]xf32>,
%d: vector<[4]xf32>,
%e: vector<[4]xf32>,
%mask: vector<[4]xi1>)
-> vector<[4]xf32> {
// CHECK: arm_sve.masked.mulf {{.*}}: vector<[4]xi1>, vector<
%0 = arm_sve.masked.mulf %mask, %a, %b : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.masked.addf {{.*}}: vector<[4]xi1>, vector<
%1 = arm_sve.masked.addf %mask, %0, %c : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.masked.subf {{.*}}: vector<[4]xi1>, vector<
%2 = arm_sve.masked.subf %mask, %1, %d : vector<[4]xi1>,
vector<[4]xf32>
// CHECK: arm_sve.masked.divf {{.*}}: vector<[4]xi1>, vector<
%3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>,
vector<[4]xf32>
return %3 : vector<[4]xf32>
}

View File

@@ -9,3 +9,11 @@
// -----
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
// expected-error@+1 {{missing ']' closing set of scalable dimensions}}
func @scalable_vector_arg(%arg0: vector<[4xf32>) { }
// -----

View File

@@ -18,3 +18,19 @@
// An unrealized N-1 conversion.
%result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type<!foo.type, !foo.type>
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
// A basic 1D scalable vector
%scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32>
// A 2D scalable vector
%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64>
// A 2D scalable vector with fixed-length dimensions
%scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16>
// A multi-dimensional vector with mixed scalable and fixed-length dimensions
%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8>

View File

@@ -578,6 +578,25 @@ func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf
return
}
// CHECK-LABEL: @vector_load_and_store_scalable_vector_memref
func @vector_load_and_store_scalable_vector_memref(%v: vector<[4]xi32>, %m: memref<?xi32>) -> vector<[4]xi32> {
%c0 = arith.constant 0 : index
// CHECK: vector.load {{.*}}: memref<?xi32>, vector<[4]xi32>
%0 = vector.load %m[%c0] : memref<?xi32>, vector<[4]xi32>
// CHECK: vector.store {{.*}}: memref<?xi32>, vector<[4]xi32>
vector.store %v, %m[%c0] : memref<?xi32>, vector<[4]xi32>
return %0 : vector<[4]xi32>
}
func @vector_load_and_store_1d_scalable_vector_memref(%memref : memref<200x100xvector<8xf32>>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
%0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
return
}
// CHECK-LABEL: @vector_load_and_store_out_of_bounds
func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
%c0 = arith.constant 0 : index
@@ -691,3 +710,10 @@ func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
vector<4x16xf32> to f32
return %2 : f32
}
// CHECK-LABEL: @get_vector_scale
func @get_vector_scale() -> index {
// CHECK: vector.vscale
%0 = vector.vscale
return %0 : index
}

View File

@@ -0,0 +1,27 @@
// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-opt | FileCheck %s
// CHECK: vector_scalable_memcopy([[SRC:%arg[0-9]+]]: memref<?xf32>, [[DST:%arg[0-9]+]]
func @vector_scalable_memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%vs = vector.vscale
%step = arith.muli %c4, %vs : index
// CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
scf.for %i0 = %c0 to %size step %step {
// CHECK: [[DATAIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
// CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
// CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
// CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>>
%0 = vector.load %src[%i0] : memref<?xf32>, vector<[4]xf32>
// CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
// CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
// CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]]{{.*}}: !llvm.ptr<vector<[4]xf32>>
vector.store %0, %dst[%i0] : memref<?xf32>, vector<[4]xf32>
}
return
}

View File

@@ -1,193 +1,193 @@
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
llvm.func @arm_sve_sdot(%arg0: !llvm.vec<?x16 x i8>,
%arg1: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
llvm.func @arm_sve_sdot(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.sdot"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
llvm.func @arm_sve_smmla(%arg0: !llvm.vec<?x16 x i8>,
%arg1: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
llvm.func @arm_sve_smmla(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.smmla"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
llvm.func @arm_sve_udot(%arg0: !llvm.vec<?x16 x i8>,
%arg1: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
llvm.func @arm_sve_udot(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.udot"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
%arg1: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
%arg1: vector<[16]xi8>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
%0 = "arm_sve.intr.ummla"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
-> vector<[4]xi32>
llvm.return %0 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
%arg1: !llvm.vec<? x 4 x i32>,
%arg2: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
%arg2: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: mul <vscale x 4 x i32>
%0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
%0 = llvm.mul %arg0, %arg1 : vector<[4]xi32>
// CHECK: add <vscale x 4 x i32>
%1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
llvm.return %1 : !llvm.vec<? x 4 x i32>
%1 = llvm.add %0, %arg2 : vector<[4]xi32>
llvm.return %1 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
%arg1: !llvm.vec<? x 4 x f32>,
%arg2: !llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x f32> {
llvm.func @arm_sve_arithf(%arg0: vector<[4]xf32>,
%arg1: vector<[4]xf32>,
%arg2: vector<[4]xf32>)
-> vector<[4]xf32> {
// CHECK: fmul <vscale x 4 x float>
%0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
%0 = llvm.fmul %arg0, %arg1 : vector<[4]xf32>
// CHECK: fadd <vscale x 4 x float>
%1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
llvm.return %1 : !llvm.vec<? x 4 x f32>
%1 = llvm.fadd %0, %arg2 : vector<[4]xf32>
llvm.return %1 : vector<[4]xf32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi_masked
llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec<? x 4 x i32>,
%arg1: !llvm.vec<? x 4 x i32>,
%arg2: !llvm.vec<? x 4 x i32>,
%arg3: !llvm.vec<? x 4 x i32>,
%arg4: !llvm.vec<? x 4 x i32>,
%arg5: !llvm.vec<? x 4 x i1>)
-> !llvm.vec<? x 4 x i32> {
llvm.func @arm_sve_arithi_masked(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
%arg2: vector<[4]xi32>,
%arg3: vector<[4]xi32>,
%arg4: vector<[4]xi32>,
%arg5: vector<[4]xi1>)
-> vector<[4]xi32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
%0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
%1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32
%2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32
%3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udiv.nxv4i32
%4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
llvm.return %4 : !llvm.vec<? x 4 x i32>
%4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
llvm.return %4 : vector<[4]xi32>
}
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf_masked
llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>,
%arg1: !llvm.vec<? x 4 x f32>,
%arg2: !llvm.vec<? x 4 x f32>,
%arg3: !llvm.vec<? x 4 x f32>,
%arg4: !llvm.vec<? x 4 x f32>,
%arg5: !llvm.vec<? x 4 x i1>)
-> !llvm.vec<? x 4 x f32> {
llvm.func @arm_sve_arithf_masked(%arg0: vector<[4]xf32>,
%arg1: vector<[4]xf32>,
%arg2: vector<[4]xf32>,
%arg3: vector<[4]xf32>,
%arg4: vector<[4]xf32>,
%arg5: vector<[4]xi1>)
-> vector<[4]xf32> {
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32
%0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x f32>,
!llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x f32>
%0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (vector<[4]xi1>,
vector<[4]xf32>,
vector<[4]xf32>)
-> vector<[4]xf32>
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32
%1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x f32>,
!llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x f32>
%1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (vector<[4]xi1>,
vector<[4]xf32>,
vector<[4]xf32>)
-> vector<[4]xf32>
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32
%2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x f32>,
!llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x f32>
%2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (vector<[4]xi1>,
vector<[4]xf32>,
vector<[4]xf32>)
-> vector<[4]xf32>
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.fdiv.nxv4f32
%3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x f32>,
!llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x f32>
llvm.return %3 : !llvm.vec<? x 4 x f32>
%3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>,
vector<[4]xf32>,
vector<[4]xf32>)
-> vector<[4]xf32>
llvm.return %3 : vector<[4]xf32>
}
// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_genf
llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec<? x 4 x f32>,
%arg1: !llvm.vec<? x 4 x f32>)
-> !llvm.vec<? x 4 x i1> {
llvm.func @arm_sve_mask_genf(%arg0: vector<[4]xf32>,
%arg1: vector<[4]xf32>)
-> vector<[4]xi1> {
// CHECK: fcmp oeq <vscale x 4 x float>
%0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec<? x 4 x f32>
llvm.return %0 : !llvm.vec<? x 4 x i1>
%0 = llvm.fcmp "oeq" %arg0, %arg1 : vector<[4]xf32>
llvm.return %0 : vector<[4]xi1>
}
// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_geni
llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec<? x 4 x i32>,
%arg1: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i1> {
llvm.func @arm_sve_mask_geni(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>)
-> vector<[4]xi1> {
// CHECK: icmp uge <vscale x 4 x i32>
%0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i1>
%0 = llvm.icmp "uge" %arg0, %arg1 : vector<[4]xi32>
llvm.return %0 : vector<[4]xi1>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_abs_diff
llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
%arg1: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
llvm.func @arm_sve_abs_diff(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>)
-> vector<[4]xi32> {
// CHECK: sub <vscale x 4 x i32>
%0 = llvm.sub %arg0, %arg0 : !llvm.vec<? x 4 x i32>
%0 = llvm.sub %arg0, %arg0 : vector<[4]xi32>
// CHECK: icmp sge <vscale x 4 x i32>
%1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
%1 = llvm.icmp "sge" %arg0, %arg1 : vector<[4]xi32>
// CHECK: icmp slt <vscale x 4 x i32>
%2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
%2 = llvm.icmp "slt" %arg0, %arg1 : vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
%3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
%4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
%5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
%5 = "arm_sve.intr.add"(%1, %0, %3) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
%6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec<? x 4 x i1>,
!llvm.vec<? x 4 x i32>,
!llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32>
llvm.return %6 : !llvm.vec<? x 4 x i32>
%6 = "arm_sve.intr.add"(%2, %5, %4) : (vector<[4]xi1>,
vector<[4]xi32>,
vector<[4]xi32>)
-> vector<[4]xi32>
llvm.return %6 : vector<[4]xi32>
}
// CHECK-LABEL: define void @memcopy
@@ -234,7 +234,7 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
%12 = llvm.mlir.constant(0 : index) : i64
%13 = llvm.mlir.constant(4 : index) : i64
// CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64()
%14 = "arm_sve.vscale"() : () -> i64
%14 = "llvm.intr.vscale"() : () -> i64
// CHECK: mul i64 [[VL]], 4
%15 = llvm.mul %14, %13 : i64
llvm.br ^bb1(%12 : i64)
@@ -249,9 +249,9 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
// CHECK: getelementptr float, float*
%19 = llvm.getelementptr %18[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: bitcast float* %{{[0-9]+}} to <vscale x 4 x float>*
%20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
%20 = llvm.bitcast %19 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
// CHECK: load <vscale x 4 x float>, <vscale x 4 x float>*
%21 = llvm.load %20 : !llvm.ptr<vec<? x 4 x f32>>
%21 = llvm.load %20 : !llvm.ptr<vector<[4]xf32>>
// CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] }
%22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
array<1 x i64>,
@@ -259,9 +259,9 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
// CHECK: getelementptr float, float* %32
%23 = llvm.getelementptr %22[%16] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: bitcast float* %33 to <vscale x 4 x float>*
%24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
%24 = llvm.bitcast %23 : !llvm.ptr<f32> to !llvm.ptr<vector<[4]xf32>>
// CHECK: store <vscale x 4 x float> %{{[0-9]+}}, <vscale x 4 x float>* %{{[0-9]+}}
llvm.store %21, %24 : !llvm.ptr<vec<? x 4 x f32>>
llvm.store %21, %24 : !llvm.ptr<vector<[4]xf32>>
%25 = llvm.add %16, %15 : i64
llvm.br ^bb1(%25 : i64)
^bb3:
@@ -271,6 +271,6 @@ llvm.func @memcopy(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>,
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()
%0 = "arm_sve.vscale"() : () -> i64
%0 = "llvm.intr.vscale"() : () -> i64
llvm.return %0 : i64
}