[mlir][linalg] Add vectorization for linalg on tensor ops

Support vectorization of linalg ops using tensor inputs/outputs.

Differential Revision: https://reviews.llvm.org/D93890
This commit is contained in:
Thomas Raoux
2020-12-28 23:55:15 -08:00
parent b76014a4f1
commit cf216670a0
3 changed files with 153 additions and 55 deletions

View File

@@ -119,34 +119,38 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
static VectorType extractVectorTypeFromScalarView(Value v) {
MemRefType mt = v.getType().cast<MemRefType>();
return mt.getShape().empty()
? VectorType()
: VectorType::get(mt.getShape(), mt.getElementType());
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
if (st.isa<MemRefType>() && st.getShape().empty())
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
static Value transferReadVector(OpBuilder &builder, Value memref) {
static Value transferReadVector(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);
auto memrefType = memref.getType().cast<MemRefType>();
if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, memref, indices);
auto shapedType = source.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, source, indices);
}
return std_load(memref);
return std_load(source);
}
static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);
auto memrefType = memref.getType().cast<MemRefType>();
if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
Operation *write;
auto shapedType = dest.getType().cast<ShapedType>();
if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
if (vectorType != value.getType())
value = vector_broadcast(vectorType, value);
vector_transfer_write(value, memref, indices);
write = vector_transfer_write(value, dest, indices);
} else {
std_store(value, memref);
write = std_store(value, dest);
}
if (!write->getResults().empty())
return write->getResult(0);
return Value();
}
namespace {
@@ -167,10 +171,12 @@ public:
void vectorize(Operation &scalarOp) {
auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
if (yieldOp) {
for (auto outputAndMemref :
llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
Value vectorValue = vectorize(std::get<0>(outputAndMemref));
transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
for (auto outputs : llvm::enumerate(yieldOp.values())) {
Value vectorValue = vectorize(outputs.value());
Value result = transferWriteVector(builder, vectorValue,
generic.getOutput(outputs.index()));
if (result)
results.push_back(result);
}
return;
}
@@ -182,6 +188,8 @@ public:
}
}
llvm::ArrayRef<Value> getResults() { return results; }
private:
// Transforms a scalar value into its vectorized counterpart, recursively
// vectorizing operations as necessary using the underlying builder.
@@ -261,6 +269,7 @@ private:
OpBuilder &builder;
linalg::GenericOp generic;
llvm::DenseMap<Value, Value> valueCache;
SmallVector<Value, 8> results;
};
} // namespace
@@ -271,6 +280,8 @@ static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
for (Operation &scalarOp : op.region().front()) {
vectorizer.vectorize(scalarOp);
}
if (!op->getResults().empty())
op->replaceAllUsesWith(vectorizer.getResults());
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -331,32 +342,14 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
auto linalgOp = cast<linalg::LinalgOp>(op);
Value viewA = linalgOp.getInput(0);
Value viewB = linalgOp.getInput(1);
Value viewC = linalgOp.getOutputBuffer(0);
VectorType vtA = extractVectorTypeFromScalarView(viewA);
VectorType vtB = extractVectorTypeFromScalarView(viewB);
VectorType vtC = extractVectorTypeFromScalarView(viewC);
Value zero = std_constant_index(0);
SmallVector<Value, 4> indicesA, indicesB, indicesC;
if (vtA)
indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
if (vtB)
indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
if (vtC)
indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
: std_load(viewA, indicesA).value;
Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
: std_load(viewB, indicesB).value;
Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
: std_load(viewC, indicesC).value;
Value a = transferReadVector(builder, linalgOp.getInput(0));
Value b = transferReadVector(builder, linalgOp.getInput(1));
Value c = transferReadVector(builder, linalgOp.getOutput(0));
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
if (vtC)
vector_transfer_write(res, viewC, indicesC);
else
std_store(res, viewC, indicesC);
Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
if (writeResult)
linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
}
/// Check whether there is any interleaved use of any `values` between `firstOp`

View File

@@ -2039,28 +2039,28 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
/// Builder that sets padding to zero.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vector, Value memref, ValueRange indices,
VectorType vector, Value source, ValueRange indices,
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
Type elemType = memref.getType().cast<MemRefType>().getElementType();
Type elemType = source.getType().cast<ShapedType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
if (maybeMasked.empty())
return build(builder, result, vector, memref, indices, permutationMap,
return build(builder, result, vector, source, indices, permutationMap,
padding, ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
build(builder, result, vector, memref, indices, permutationMap, padding,
build(builder, result, vector, source, indices, permutationMap, padding,
maskedArrayAttr);
}
/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
/// (resp. zero).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value memref,
VectorType vectorType, Value source,
ValueRange indices, ArrayRef<bool> maybeMasked) {
auto permMap = getTransferMinorIdentityMap(
memref.getType().cast<MemRefType>(), vectorType);
build(builder, result, vectorType, memref, indices, permMap, maybeMasked);
source.getType().cast<ShapedType>(), vectorType);
build(builder, result, vectorType, source, indices, permMap, maybeMasked);
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
@@ -2251,7 +2251,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
source.getType().cast<MemRefType>(), vectorType);
source.getType().cast<ShapedType>(), vectorType);
if (maybeMasked.empty())
return build(builder, result, vector, source, indices, permMap,
ArrayAttr());
@@ -2327,7 +2327,7 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
}
static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
// Consistency of elemental types in shape and vector.
ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto permutationMap = op.permutation_map();

View File

@@ -189,7 +189,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
@@ -209,3 +209,108 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
// CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
%arg1: tensor<4x256xf32>, %arg2: tensor<256xf32>,
%i: f32) -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) {
%c1_f32 = constant 1.0 : f32
%r:10 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg1, %arg2: tensor<4x256xf32>, tensor<256xf32>)
outs(
%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
%arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
%arg14 : f32):
%6 = addf %arg4, %arg6 : f32
%7 = cmpf "ogt", %arg3, %arg6 : f32
%8 = constant 2.0 : f32
%9 = divf %arg5, %i : f32
%10 = exp2 %arg5 : f32
%11 = mulf %arg5, %8 : f32
%12 = rsqrt %arg5 : f32
%13 = select %7, %arg5, %arg6 : f32
%14 = subf %arg5, %arg4 : f32
%15 = tanh %arg5 : f32
linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
f32, f32, f32, f32, f32, f32, f32, f32
} -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9:
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,
tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
}
// CHECK-LABEL: func @generic_vectorize_tensor
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x256xf32>, %[[ARG1:.*]]: tensor<4x256xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<256xf32>, %[[ARG3:.*]]: f32)
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32>
// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
// CHECK: %[[R0:.*]] = vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R1:.*]] = vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R2:.*]] = vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R3:.*]] = vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R4:.*]] = vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R5:.*]] = vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R6:.*]] = vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R7:.*]] = vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R8:.*]] = vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
func @matmul_tensors(
%arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>)
-> tensor<8x12xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
outs(%arg2: tensor<8x12xf32>)
-> tensor<8x12xf32>
return %0 : tensor<8x12xf32>
}
// CHECK-LABEL: func @matmul_tensors
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[V2]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[C]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
// CHECK: return %[[W]] : tensor<8x12xf32>