[mlir][linalg] Add vectorization for linalg on tensor ops

Support vectorization of linalg ops using tensor inputs/outputs.

Differential Revision: https://reviews.llvm.org/D93890
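For context on how these entry points are driven, a minimal caller sketch follows. The wrapper name `tryVectorizeLinalgOp` and the includes are illustrative assumptions; only `vectorizeLinalgOpPrecondition` and `vectorizeLinalgOp`, whose signatures appear in the diff below, come from the patch itself.

// Illustrative driver, not part of the patch: assumes the entry points are
// declared in Linalg's Transforms.h, as in this revision.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"

static mlir::LogicalResult tryVectorizeLinalgOp(mlir::OpBuilder &builder,
                                                mlir::Operation *op) {
  // Bail out on ops the vectorizer does not support.
  if (mlir::failed(mlir::linalg::vectorizeLinalgOpPrecondition(op)))
    return mlir::failure();
  // Rewrites the op using vector.transfer_read / vector.contract (or
  // elementwise vector ops) / vector.transfer_write; with tensor outputs the
  // transfer_write results replace the linalg op's results.
  mlir::linalg::vectorizeLinalgOp(builder, op);
  return mlir::success();
}

The new tests at the bottom of the diff show the resulting IR for both an elementwise linalg.generic and a linalg.matmul on tensors.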
@@ -119,34 +119,38 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }

-static VectorType extractVectorTypeFromScalarView(Value v) {
-  MemRefType mt = v.getType().cast<MemRefType>();
-  return mt.getShape().empty()
-             ? VectorType()
-             : VectorType::get(mt.getShape(), mt.getElementType());
+static VectorType extractVectorTypeFromShapedValue(Value v) {
+  auto st = v.getType().cast<ShapedType>();
+  if (st.isa<MemRefType>() && st.getShape().empty())
+    return VectorType();
+  return VectorType::get(st.getShape(), st.getElementType());
 }

-static Value transferReadVector(OpBuilder &builder, Value memref) {
+static Value transferReadVector(OpBuilder &builder, Value source) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
-    return vector_transfer_read(vectorType, memref, indices);
+  auto shapedType = source.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
+    return vector_transfer_read(vectorType, source, indices);
   }
-  return std_load(memref);
+  return std_load(source);
 }

-static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
   edsc::ScopedContext scope(builder);
-  auto memrefType = memref.getType().cast<MemRefType>();
-  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
-    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+  Operation *write;
+  auto shapedType = dest.getType().cast<ShapedType>();
+  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
+    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
     if (vectorType != value.getType())
       value = vector_broadcast(vectorType, value);
-    vector_transfer_write(value, memref, indices);
+    write = vector_transfer_write(value, dest, indices);
   } else {
-    std_store(value, memref);
+    write = std_store(value, dest);
   }
+  if (!write->getResults().empty())
+    return write->getResult(0);
+  return Value();
 }

 namespace {
@@ -167,10 +171,12 @@ public:
   void vectorize(Operation &scalarOp) {
     auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
     if (yieldOp) {
-      for (auto outputAndMemref :
-           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
-        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
-        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+      for (auto outputs : llvm::enumerate(yieldOp.values())) {
+        Value vectorValue = vectorize(outputs.value());
+        Value result = transferWriteVector(builder, vectorValue,
+                                           generic.getOutput(outputs.index()));
+        if (result)
+          results.push_back(result);
       }
       return;
     }
@@ -182,6 +188,8 @@ public:
     }
   }

+  llvm::ArrayRef<Value> getResults() { return results; }
+
 private:
   // Transforms a scalar value into its vectorized counterpart, recursively
   // vectorizing operations as necessary using the underlying builder.
@@ -261,6 +269,7 @@ private:
   OpBuilder &builder;
   linalg::GenericOp generic;
   llvm::DenseMap<Value, Value> valueCache;
+  SmallVector<Value, 8> results;
 };
 } // namespace

@@ -271,6 +280,8 @@ static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
   for (Operation &scalarOp : op.region().front()) {
     vectorizer.vectorize(scalarOp);
   }
+  if (!op->getResults().empty())
+    op->replaceAllUsesWith(vectorizer.getResults());
 }

 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -331,32 +342,14 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value viewA = linalgOp.getInput(0);
-  Value viewB = linalgOp.getInput(1);
-  Value viewC = linalgOp.getOutputBuffer(0);
-  VectorType vtA = extractVectorTypeFromScalarView(viewA);
-  VectorType vtB = extractVectorTypeFromScalarView(viewB);
-  VectorType vtC = extractVectorTypeFromScalarView(viewC);
-  Value zero = std_constant_index(0);
-  SmallVector<Value, 4> indicesA, indicesB, indicesC;
-  if (vtA)
-    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
-  if (vtB)
-    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
-  if (vtC)
-    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
-  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
-                : std_load(viewA, indicesA).value;
-  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
-                : std_load(viewB, indicesB).value;
-  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
-                : std_load(viewC, indicesC).value;
+  Value a = transferReadVector(builder, linalgOp.getInput(0));
+  Value b = transferReadVector(builder, linalgOp.getInput(1));
+  Value c = transferReadVector(builder, linalgOp.getOutput(0));
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  if (vtC)
-    vector_transfer_write(res, viewC, indicesC);
-  else
-    std_store(res, viewC, indicesC);
+  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
+  if (writeResult)
+    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
 }

 /// Check whether there is any interleaved use of any `values` between `firstOp`

@@ -2039,28 +2039,28 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,

 /// Builder that sets padding to zero.
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vector, Value memref, ValueRange indices,
+                           VectorType vector, Value source, ValueRange indices,
                            AffineMap permutationMap,
                            ArrayRef<bool> maybeMasked) {
-  Type elemType = memref.getType().cast<MemRefType>().getElementType();
+  Type elemType = source.getType().cast<ShapedType>().getElementType();
   Value padding = builder.create<ConstantOp>(result.location, elemType,
                                              builder.getZeroAttr(elemType));
   if (maybeMasked.empty())
-    return build(builder, result, vector, memref, indices, permutationMap,
+    return build(builder, result, vector, source, indices, permutationMap,
                  padding, ArrayAttr());
   ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
-  build(builder, result, vector, memref, indices, permutationMap, padding,
+  build(builder, result, vector, source, indices, permutationMap, padding,
        maskedArrayAttr);
 }

 /// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
 /// (resp. zero).
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value memref,
+                           VectorType vectorType, Value source,
                            ValueRange indices, ArrayRef<bool> maybeMasked) {
   auto permMap = getTransferMinorIdentityMap(
-      memref.getType().cast<MemRefType>(), vectorType);
-  build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
+      source.getType().cast<ShapedType>(), vectorType);
+  build(builder, result, vectorType, source, indices, permMap, maybeMasked);
 }

 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
@@ -2251,7 +2251,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             ArrayRef<bool> maybeMasked) {
   auto vectorType = vector.getType().cast<VectorType>();
   auto permMap = getTransferMinorIdentityMap(
-      source.getType().cast<MemRefType>(), vectorType);
+      source.getType().cast<ShapedType>(), vectorType);
   if (maybeMasked.empty())
     return build(builder, result, vector, source, indices, permMap,
                  ArrayAttr());
@@ -2327,7 +2327,7 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
 }

 static LogicalResult verify(TransferWriteOp op) {
-  // Consistency of elemental types in memref and vector.
+  // Consistency of elemental types in shape and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
   auto permutationMap = op.permutation_map();

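With the ShapedType-based builders and verifier wording above, the same transfer-op builders previously restricted to memrefs now accept tensor operands. A minimal sketch follows, assuming the Vector dialect header path of this revision; the helper name and surrounding setup are illustrative, not part of the patch.

// Illustrative only, not part of the patch.
#include "mlir/Dialect/Vector/VectorOps.h"

// Reads a full static-shaped tensor as a vector starting at the given zero
// indices, relying on the builder that defaults the permutation map to the
// minor identity and the padding to zero.
static mlir::Value readFullTensorAsVector(mlir::OpBuilder &b, mlir::Location loc,
                                          mlir::Value tensor,
                                          mlir::ValueRange zeroIndices) {
  auto shapedType = tensor.getType().cast<mlir::ShapedType>();
  auto vecType = mlir::VectorType::get(shapedType.getShape(),
                                       shapedType.getElementType());
  // Unmasked read; the builder materializes the zero padding value itself.
  return b.create<mlir::vector::TransferReadOp>(loc, vecType, tensor,
                                                zeroIndices,
                                                llvm::ArrayRef<bool>{})
      .getResult();
}

The same call continues to work for a memref source, since a memref is also a ShapedType.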
@@ -189,7 +189,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
 // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 // CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
-// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
 // CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
 // CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
 // CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -209,3 +209,108 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 // CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 // CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 // CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+
+func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
+  %arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
+  %i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+  tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+  tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
+  %c1_f32 = constant 1.0 : f32
+  %r:10 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
+    outs(
+    %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+    %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+    %arg14 : f32):
+    %6 = addf %arg4, %arg6 : f32
+    %7 = cmpf "ogt", %arg3, %arg6 : f32
+    %8 = constant 2.0 : f32
+    %9 = divf %arg5, %i : f32
+    %10 = exp2 %arg5 : f32
+    %11 = mulf %arg5, %8 : f32
+    %12 = rsqrt %arg5 : f32
+    %13 = select %7, %arg5, %arg6 : f32
+    %14 = subf %arg5, %arg4 : f32
+    %15 = tanh %arg5 : f32
+    linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+      f32, f32, f32, f32, f32, f32, f32, f32
+  } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+  return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
+    tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+}
+
+// CHECK-LABEL: func @generic_vectorize_tensor
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
+// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
+// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
+
+func @matmul_tensors(
+  %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
+                     outs(%arg2: tensor<8x12xf32>)
+    -> tensor<8x12xf32>
+  return %0 : tensor<8x12xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
+// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
+// CHECK: return %[[W]] : tensor<8x12xf32>