2019-11-14 15:39:36 -08:00
|
|
|
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
|
2019-08-12 04:08:26 -07:00
|
|
|
//
|
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
2019-12-03 17:51:34 -08:00
|
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
2019-08-19 11:00:47 -07:00
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
2019-08-19 17:11:12 -07:00
|
|
|
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
2019-08-12 04:08:26 -07:00
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
|
|
|
#include "mlir/IR/Module.h"
|
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
|
#include "mlir/Pass/PassManager.h"
|
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
|
|
|
|
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
|
|
|
#include "llvm/IR/Module.h"
|
|
|
|
|
#include "llvm/IR/Type.h"
|
|
|
|
|
#include "llvm/Support/Allocator.h"
|
|
|
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static LLVM::LLVMType getPtrToElementType(T containerType,
|
|
|
|
|
LLVMTypeConverter &lowering) {
|
|
|
|
|
return lowering.convertType(containerType.getElementType())
|
|
|
|
|
.template cast<LLVM::LLVMType>()
|
|
|
|
|
.getPointerTo();
|
|
|
|
|
}
|
|
|
|
|
|
[VectorOps] Add lowering of vector.insert to LLVM IR
For example, an insert
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
becomes
%0 = llvm.mlir.constant(3 : i32) : !llvm.i32
%1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">
A more elaborate example, inserting an element in a higher dimension
vector
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
becomes
%0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
%1 = llvm.mlir.constant(15 : i32) : !llvm.i32
%2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
%3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
PiperOrigin-RevId: 284882443
2019-12-10 17:12:11 -08:00
|
|
|
// Helper to reduce vector type by one rank at front.
|
|
|
|
|
static VectorType reducedVectorTypeFront(VectorType tp) {
|
|
|
|
|
assert((tp.getRank() > 1) && "unlowerable vector type");
|
|
|
|
|
return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Helper to reduce vector type by *all* but one rank at back.
|
|
|
|
|
static VectorType reducedVectorTypeBack(VectorType tp) {
|
|
|
|
|
assert((tp.getRank() > 1) && "unlowerable vector type");
|
|
|
|
|
return VectorType::get(tp.getShape().take_back(), tp.getElementType());
|
|
|
|
|
}
|
|
|
|
|
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
// Helper that picks the proper sequence for inserting.
|
|
|
|
|
static Value *insertOne(ConversionPatternRewriter &rewriter,
|
|
|
|
|
LLVMTypeConverter &lowering, Location loc, Value *val1,
|
|
|
|
|
Value *val2, Type llvmType, int64_t rank, int64_t pos) {
|
|
|
|
|
if (rank == 1) {
|
|
|
|
|
auto idxType = rewriter.getIndexType();
|
|
|
|
|
auto constant = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
loc, lowering.convertType(idxType),
|
|
|
|
|
rewriter.getIntegerAttr(idxType, pos));
|
|
|
|
|
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
|
|
|
|
|
constant);
|
|
|
|
|
}
|
|
|
|
|
return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
|
|
|
|
|
rewriter.getI64ArrayAttr(pos));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Helper that picks the proper sequence for extracting.
|
|
|
|
|
static Value *extractOne(ConversionPatternRewriter &rewriter,
|
|
|
|
|
LLVMTypeConverter &lowering, Location loc, Value *val,
|
|
|
|
|
Type llvmType, int64_t rank, int64_t pos) {
|
|
|
|
|
if (rank == 1) {
|
|
|
|
|
auto idxType = rewriter.getIndexType();
|
|
|
|
|
auto constant = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
|
loc, lowering.convertType(idxType),
|
|
|
|
|
rewriter.getIntegerAttr(idxType, pos));
|
|
|
|
|
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
|
|
|
|
|
constant);
|
|
|
|
|
}
|
|
|
|
|
return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
|
|
|
|
|
rewriter.getI64ArrayAttr(pos));
|
|
|
|
|
}
|
|
|
|
|
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
class VectorBroadcastOpConversion : public LLVMOpLowering {
|
|
|
|
|
public:
|
|
|
|
|
explicit VectorBroadcastOpConversion(MLIRContext *context,
|
|
|
|
|
LLVMTypeConverter &typeConverter)
|
|
|
|
|
: LLVMOpLowering(vector::BroadcastOp::getOperationName(), context,
|
|
|
|
|
typeConverter) {}
|
|
|
|
|
|
|
|
|
|
PatternMatchResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto broadcastOp = cast<vector::BroadcastOp>(op);
|
|
|
|
|
VectorType dstVectorType = broadcastOp.getVectorType();
|
|
|
|
|
if (lowering.convertType(dstVectorType) == nullptr)
|
|
|
|
|
return matchFailure();
|
|
|
|
|
// Rewrite when the full vector type can be lowered (which
|
|
|
|
|
// implies all 'reduced' types can be lowered too).
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
VectorType srcVectorType =
|
|
|
|
|
broadcastOp.getSourceType().dyn_cast<VectorType>();
|
|
|
|
|
rewriter.replaceOp(
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
op, expandRanks(adaptor.source(), // source value to be expanded
|
|
|
|
|
op->getLoc(), // location of original broadcast
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
srcVectorType, dstVectorType, rewriter));
|
|
|
|
|
return matchSuccess();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Expands the given source value over all the ranks, as defined
|
|
|
|
|
// by the source and destination type (a null source type denotes
|
|
|
|
|
// expansion from a scalar value into a vector).
|
|
|
|
|
//
|
|
|
|
|
// TODO(ajcbik): consider replacing this one-pattern lowering
|
|
|
|
|
// with a two-pattern lowering using other vector
|
|
|
|
|
// ops once all insert/extract/shuffle operations
|
|
|
|
|
// are available with lowering implemention.
|
|
|
|
|
//
|
|
|
|
|
Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
|
|
|
|
|
VectorType dstVectorType,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
assert((dstVectorType != nullptr) && "invalid result type in broadcast");
|
|
|
|
|
// Determine rank of source and destination.
|
|
|
|
|
int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
|
|
|
|
|
int64_t dstRank = dstVectorType.getRank();
|
|
|
|
|
int64_t curDim = dstVectorType.getDimSize(0);
|
|
|
|
|
if (srcRank < dstRank)
|
|
|
|
|
// Duplicate this rank.
|
|
|
|
|
return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
|
|
|
|
|
curDim, rewriter);
|
|
|
|
|
// If all trailing dimensions are the same, the broadcast consists of
|
|
|
|
|
// simply passing through the source value and we are done. Otherwise,
|
|
|
|
|
// any non-matching dimension forces a stretch along this rank.
|
|
|
|
|
assert((srcVectorType != nullptr) && (srcRank > 0) &&
|
|
|
|
|
(srcRank == dstRank) && "invalid rank in broadcast");
|
|
|
|
|
for (int64_t r = 0; r < dstRank; r++) {
|
|
|
|
|
if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
|
|
|
|
|
return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
|
|
|
|
|
curDim, rewriter);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Picks the best way to duplicate a single rank. For the 1-D case, a
|
|
|
|
|
// single insert-elt/shuffle is the most efficient expansion. For higher
|
|
|
|
|
// dimensions, however, we need dim x insert-values on a new broadcast
|
|
|
|
|
// with one less leading dimension, which will be lowered "recursively"
|
|
|
|
|
// to matching LLVM IR.
|
|
|
|
|
// For example:
|
|
|
|
|
// v = broadcast s : f32 to vector<4x2xf32>
|
|
|
|
|
// becomes:
|
|
|
|
|
// x = broadcast s : f32 to vector<2xf32>
|
|
|
|
|
// v = [x,x,x,x]
|
|
|
|
|
// becomes:
|
|
|
|
|
// x = [s,s]
|
|
|
|
|
// v = [x,x,x,x]
|
|
|
|
|
Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
|
|
|
|
|
VectorType dstVectorType, int64_t rank, int64_t dim,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Type llvmType = lowering.convertType(dstVectorType);
|
|
|
|
|
assert((llvmType != nullptr) && "unlowerable vector type");
|
|
|
|
|
if (rank == 1) {
|
|
|
|
|
Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
Value *expand =
|
|
|
|
|
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
SmallVector<int32_t, 4> zeroValues(dim, 0);
|
|
|
|
|
return rewriter.create<LLVM::ShuffleVectorOp>(
|
|
|
|
|
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
|
|
|
|
|
}
|
[VectorOps] Add lowering of vector.insert to LLVM IR
For example, an insert
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
becomes
%0 = llvm.mlir.constant(3 : i32) : !llvm.i32
%1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">
A more elaborate example, inserting an element in a higher dimension
vector
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
becomes
%0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
%1 = llvm.mlir.constant(15 : i32) : !llvm.i32
%2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
%3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
PiperOrigin-RevId: 284882443
2019-12-10 17:12:11 -08:00
|
|
|
Value *expand =
|
|
|
|
|
expandRanks(value, loc, srcVectorType,
|
|
|
|
|
reducedVectorTypeFront(dstVectorType), rewriter);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
|
|
|
|
for (int64_t d = 0; d < dim; ++d) {
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
result =
|
|
|
|
|
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Picks the best way to stretch a single rank. For the 1-D case, a
|
|
|
|
|
// single insert-elt/shuffle is the most efficient expansion when at
|
|
|
|
|
// a stretch. Otherwise, every dimension needs to be expanded
|
|
|
|
|
// individually and individually inserted in the resulting vector.
|
|
|
|
|
// For example:
|
|
|
|
|
// v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
|
|
|
|
|
// becomes:
|
|
|
|
|
// a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
|
|
|
|
|
// b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
|
|
|
|
|
// c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
|
|
|
|
|
// d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
|
|
|
|
|
// v = [a,b,c,d]
|
|
|
|
|
// becomes:
|
|
|
|
|
// x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
|
|
|
|
|
// y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
|
|
|
|
|
// a = [x, y]
|
|
|
|
|
// etc.
|
|
|
|
|
Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
|
|
|
|
|
VectorType dstVectorType, int64_t rank, int64_t dim,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Type llvmType = lowering.convertType(dstVectorType);
|
|
|
|
|
assert((llvmType != nullptr) && "unlowerable vector type");
|
|
|
|
|
Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
|
|
|
|
bool atStretch = dim != srcVectorType.getDimSize(0);
|
|
|
|
|
if (rank == 1) {
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
assert(atStretch);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
Value *one =
|
|
|
|
|
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
|
|
|
|
|
Value *expand =
|
|
|
|
|
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
|
|
|
|
|
SmallVector<int32_t, 4> zeroValues(dim, 0);
|
|
|
|
|
return rewriter.create<LLVM::ShuffleVectorOp>(
|
|
|
|
|
loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
|
|
|
|
|
}
|
|
|
|
|
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
|
|
|
|
|
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
|
|
|
|
|
Type redLlvmType = lowering.convertType(redSrcType);
|
|
|
|
|
for (int64_t d = 0; d < dim; ++d) {
|
|
|
|
|
int64_t pos = atStretch ? 0 : d;
|
|
|
|
|
Value *one =
|
|
|
|
|
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
|
|
|
|
|
Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
|
|
|
|
|
result =
|
|
|
|
|
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
};
|
[VectorOps] Add lowering of vector.broadcast to LLVM IR
For example, a scalar broadcast
%0 = vector.broadcast %x : f32 to vector<2xf32>
return %0 : vector<2xf32>
which expands scalar x into vector [x,x] by lowering
to the following LLVM IR dialect to implement the
duplication over the leading dimension.
%0 = llvm.mlir.undef : !llvm<"<2 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.insertelement %x, %0[%1 : !llvm.i64] : !llvm<"<2 x float>">
%3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
return %3 : vector<2xf32>
In the trailing dimensions, the operand is simply
"passed through", unless a more elaborate "stretch"
is required.
For example
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32>
return %0 : vector<4xf32>
becomes
%0 = llvm.mlir.undef : !llvm<"<4 x float>">
%1 = llvm.mlir.constant(0 : index) : !llvm.i64
%2 = llvm.extractelement %arg0[%1 : !llvm.i64] : !llvm<"<1 x float>">
%3 = llvm.mlir.constant(0 : index) : !llvm.i64
%4 = llvm.insertelement %2, %0[%3 : !llvm.i64] : !llvm<"<4 x float>">
%5 = llvm.shufflevector %4, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
llvm.return %5 : !llvm<"<4 x float>">
PiperOrigin-RevId: 284219926
2019-12-06 11:01:54 -08:00
|
|
|
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
/// Conversion pattern lowering vector.shuffle to the LLVM dialect.
///
/// Rank-1 shuffles whose operands have identical vector types map directly
/// onto llvm.shufflevector. All other cases are expanded into a sequence of
/// per-mask-entry extract/insert operations on an initially-undef result.
class VectorShuffleOpConversion : public LLVMOpLowering {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = lowering.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return matchFailure();

    // Get rank and dimension sizes.
    // The verifier guarantees operands and result agree on rank; the asserts
    // merely re-check that invariant in debug builds.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    // Leading dimension of v1: mask indices below v1Dim select from v1,
    // indices at or above it select from v2.
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return matchSuccess();
    }

    // For all other cases, insert the individual values individually.
    // Start from an undef aggregate and fill result position insPos with the
    // source value selected by each mask entry in turn.
    Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value *value = adaptor.v1();
      // Rebase indices >= v1Dim into v2's index space.
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      // extractOne/insertOne are file-local helpers defined elsewhere in this
      // file; presumably they dispatch between extractelement/insertelement
      // (rank 1) and extractvalue/insertvalue (rank > 1) -- TODO confirm.
      Value *extract =
          extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
      insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
                         rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return matchSuccess();
  }
};
|
|
|
|
|
|
2019-12-16 09:52:13 -08:00
|
|
|
class VectorExtractElementOpConversion : public LLVMOpLowering {
|
|
|
|
|
public:
|
|
|
|
|
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
|
|
|
|
LLVMTypeConverter &typeConverter)
|
|
|
|
|
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
|
|
|
|
|
typeConverter) {}
|
|
|
|
|
|
|
|
|
|
PatternMatchResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
|
|
|
|
auto extractEltOp = cast<vector::ExtractElementOp>(op);
|
|
|
|
|
auto vectorType = extractEltOp.getVectorType();
|
|
|
|
|
auto llvmType = lowering.convertType(vectorType.getElementType());
|
|
|
|
|
|
|
|
|
|
// Bail if result type cannot be lowered.
|
|
|
|
|
if (!llvmType)
|
|
|
|
|
return matchFailure();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
|
|
|
|
op, llvmType, adaptor.vector(), adaptor.position());
|
|
|
|
|
return matchSuccess();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
[VectorOps] Add lowering of vector.insert to LLVM IR
For example, an insert
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
becomes
%0 = llvm.mlir.constant(3 : i32) : !llvm.i32
%1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">
A more elaborate example, inserting an element in a higher dimension
vector
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
becomes
%0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
%1 = llvm.mlir.constant(15 : i32) : !llvm.i32
%2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
%3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
PiperOrigin-RevId: 284882443
2019-12-10 17:12:11 -08:00
|
|
|
/// Conversion pattern lowering vector.extract to the LLVM dialect.
///
/// A vector-typed result is produced by a single llvm.extractvalue. A scalar
/// result first peels the enclosing array dimensions with llvm.extractvalue
/// (when the position has more than one index) and then extracts the element
/// from the remaining 1-D vector with llvm.extractelement.
class VectorExtractOpConversion : public LLVMOpLowering {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpOperandAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult()->getType();
    auto llvmResultType = lowering.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return matchFailure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return matchSuccess();
    }

    // Potential extraction of 1-D vector from array.
    // With more than one position index, all but the last index address the
    // surrounding array structure; strip them with a single extractvalue.
    auto *context = op->getContext();
    Value *extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      // reducedVectorTypeBack is a file-local helper defined elsewhere in
      // this file; presumably it drops all but the trailing dimension of
      // the vector type -- TODO confirm against its definition.
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    // (llvm.extractelement requires the index as an SSA value, hence the
    // materialized i32 constant).
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return matchSuccess();
  }
};
|
|
|
|
|
|
2019-12-16 09:52:13 -08:00
|
|
|
class VectorInsertElementOpConversion : public LLVMOpLowering {
|
|
|
|
|
public:
|
|
|
|
|
explicit VectorInsertElementOpConversion(MLIRContext *context,
|
|
|
|
|
LLVMTypeConverter &typeConverter)
|
|
|
|
|
: LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
|
|
|
|
|
typeConverter) {}
|
|
|
|
|
|
|
|
|
|
PatternMatchResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
|
|
|
|
auto insertEltOp = cast<vector::InsertElementOp>(op);
|
|
|
|
|
auto vectorType = insertEltOp.getDestVectorType();
|
|
|
|
|
auto llvmType = lowering.convertType(vectorType);
|
|
|
|
|
|
|
|
|
|
// Bail if result type cannot be lowered.
|
|
|
|
|
if (!llvmType)
|
|
|
|
|
return matchFailure();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
|
|
|
|
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
|
|
|
|
return matchSuccess();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
[VectorOps] Add lowering of vector.insert to LLVM IR
For example, an insert
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
becomes
%0 = llvm.mlir.constant(3 : i32) : !llvm.i32
%1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">
A more elaborate example, inserting an element in a higher dimension
vector
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
becomes
%0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
%1 = llvm.mlir.constant(15 : i32) : !llvm.i32
%2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
%3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
PiperOrigin-RevId: 284882443
2019-12-10 17:12:11 -08:00
|
|
|
/// Conversion pattern lowering vector.insert to the LLVM dialect.
///
/// Inserting a vector-typed source is a single llvm.insertvalue. Inserting a
/// scalar proceeds in three steps: (1) extract the target 1-D vector from the
/// enclosing array (if the position has more than one index), (2) insert the
/// scalar into that 1-D vector with llvm.insertelement, and (3) insert the
/// updated 1-D vector back into the destination array.
class VectorInsertOpConversion : public LLVMOpLowering {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = lowering.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return matchFailure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value *inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return matchSuccess();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position index address the surrounding array
    // structure; the last index selects the element within a 1-D vector.
    auto *context = op->getContext();
    Value *extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      // reducedVectorTypeBack is a file-local helper defined elsewhere in
      // this file; presumably it drops all but the trailing dimension of
      // the vector type -- TODO confirm against its definition.
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    // (llvm.insertelement requires the index as an SSA value, hence the
    // materialized i32 constant.)
    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
    Value *inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
        constant);

    // Potential insertion of resulting 1-D vector into array.
    // Writes the updated 1-D vector back into the original destination.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return matchSuccess();
  }
};
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
class VectorOuterProductOpConversion : public LLVMOpLowering {
|
2019-08-12 04:08:26 -07:00
|
|
|
public:
|
2019-11-18 10:38:35 -08:00
|
|
|
explicit VectorOuterProductOpConversion(MLIRContext *context,
|
|
|
|
|
LLVMTypeConverter &typeConverter)
|
2019-11-22 07:52:02 -08:00
|
|
|
: LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
|
|
|
|
|
typeConverter) {}
|
2019-08-12 04:08:26 -07:00
|
|
|
|
|
|
|
|
PatternMatchResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
2019-11-22 07:52:02 -08:00
|
|
|
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
|
2019-08-12 04:08:26 -07:00
|
|
|
auto *ctx = op->getContext();
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
|
|
|
|
|
auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
|
|
|
|
|
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
|
|
|
|
|
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
|
2019-08-12 04:08:26 -07:00
|
|
|
auto llvmArrayOfVectType = lowering.convertType(
|
2019-11-22 07:52:02 -08:00
|
|
|
cast<vector::OuterProductOp>(op).getResult()->getType());
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
|
|
|
|
|
Value *a = adaptor.lhs(), *b = adaptor.rhs();
|
|
|
|
|
Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
|
|
|
|
|
SmallVector<Value *, 8> lhs, accs;
|
|
|
|
|
lhs.reserve(rankLHS);
|
|
|
|
|
accs.reserve(rankLHS);
|
|
|
|
|
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
|
|
|
|
|
// shufflevector explicitly requires i32.
|
|
|
|
|
auto attr = rewriter.getI32IntegerAttr(d);
|
|
|
|
|
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
|
|
|
|
|
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
|
|
|
|
|
Value *aD = nullptr, *accD = nullptr;
|
|
|
|
|
// 1. Broadcast the element a[d] into vector aD.
|
|
|
|
|
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
|
|
|
|
|
// 2. If acc is present, extract 1-d vector acc[d] into accD.
|
|
|
|
|
if (acc)
|
2019-09-16 03:30:33 -07:00
|
|
|
accD = rewriter.create<LLVM::ExtractValueOp>(
|
|
|
|
|
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
// 3. Compute aD outer b (plus accD, if relevant).
|
|
|
|
|
Value *aOuterbD =
|
2019-10-11 05:38:10 -07:00
|
|
|
accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
.getResult()
|
|
|
|
|
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
|
|
|
|
|
// 4. Insert as value `d` in the descriptor.
|
2019-09-16 03:30:33 -07:00
|
|
|
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
|
|
|
|
|
desc, aOuterbD,
|
|
|
|
|
rewriter.getI64ArrayAttr(d));
|
2019-08-12 04:08:26 -07:00
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, desc);
|
|
|
|
|
return matchSuccess();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2019-10-18 13:48:26 -07:00
|
|
|
class VectorTypeCastOpConversion : public LLVMOpLowering {
|
|
|
|
|
public:
|
|
|
|
|
explicit VectorTypeCastOpConversion(MLIRContext *context,
|
|
|
|
|
LLVMTypeConverter &typeConverter)
|
2019-11-22 07:52:02 -08:00
|
|
|
: LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
|
2019-10-18 13:48:26 -07:00
|
|
|
typeConverter) {}
|
|
|
|
|
|
|
|
|
|
PatternMatchResult
|
|
|
|
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op->getLoc();
|
2019-11-22 07:52:02 -08:00
|
|
|
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
|
2019-10-18 13:48:26 -07:00
|
|
|
MemRefType sourceMemRefType =
|
|
|
|
|
castOp.getOperand()->getType().cast<MemRefType>();
|
|
|
|
|
MemRefType targetMemRefType =
|
|
|
|
|
castOp.getResult()->getType().cast<MemRefType>();
|
|
|
|
|
|
|
|
|
|
// Only static shape casts supported atm.
|
|
|
|
|
if (!sourceMemRefType.hasStaticShape() ||
|
|
|
|
|
!targetMemRefType.hasStaticShape())
|
|
|
|
|
return matchFailure();
|
|
|
|
|
|
|
|
|
|
auto llvmSourceDescriptorTy =
|
2019-11-14 09:05:11 -08:00
|
|
|
operands[0]->getType().dyn_cast<LLVM::LLVMType>();
|
2019-10-18 13:48:26 -07:00
|
|
|
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
|
|
|
|
|
return matchFailure();
|
2019-11-14 09:05:11 -08:00
|
|
|
MemRefDescriptor sourceMemRef(operands[0]);
|
2019-10-18 13:48:26 -07:00
|
|
|
|
|
|
|
|
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
|
|
|
|
|
.dyn_cast_or_null<LLVM::LLVMType>();
|
|
|
|
|
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
|
|
|
|
return matchFailure();
|
|
|
|
|
|
|
|
|
|
int64_t offset;
|
|
|
|
|
SmallVector<int64_t, 4> strides;
|
|
|
|
|
auto successStrides =
|
2019-11-14 08:10:36 -08:00
|
|
|
getStridesAndOffset(sourceMemRefType, strides, offset);
|
2019-10-18 13:48:26 -07:00
|
|
|
bool isContiguous = (strides.back() == 1);
|
|
|
|
|
if (isContiguous) {
|
2019-11-14 08:10:36 -08:00
|
|
|
auto sizes = sourceMemRefType.getShape();
|
2019-10-18 13:48:26 -07:00
|
|
|
for (int index = 0, e = strides.size() - 2; index < e; ++index) {
|
|
|
|
|
if (strides[index] != strides[index + 1] * sizes[index + 1]) {
|
|
|
|
|
isContiguous = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-11-14 08:10:36 -08:00
|
|
|
// Only contiguous source tensors supported atm.
|
2019-10-18 13:48:26 -07:00
|
|
|
if (failed(successStrides) || !isContiguous)
|
|
|
|
|
return matchFailure();
|
|
|
|
|
|
|
|
|
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
|
|
|
|
|
|
|
|
|
|
// Create descriptor.
|
2019-11-14 09:05:11 -08:00
|
|
|
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
|
|
|
|
Type llvmTargetElementTy = desc.getElementType();
|
2019-11-12 07:06:18 -08:00
|
|
|
// Set allocated ptr.
|
2019-11-14 09:05:11 -08:00
|
|
|
Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
|
2019-11-12 07:06:18 -08:00
|
|
|
allocated =
|
|
|
|
|
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
|
2019-11-14 09:05:11 -08:00
|
|
|
desc.setAllocatedPtr(rewriter, loc, allocated);
|
|
|
|
|
// Set aligned ptr.
|
|
|
|
|
Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
|
2019-10-18 13:48:26 -07:00
|
|
|
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
|
2019-11-14 09:05:11 -08:00
|
|
|
desc.setAlignedPtr(rewriter, loc, ptr);
|
2019-10-18 13:48:26 -07:00
|
|
|
// Fill offset 0.
|
|
|
|
|
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
|
|
|
|
|
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
|
2019-11-14 09:05:11 -08:00
|
|
|
desc.setOffset(rewriter, loc, zero);
|
|
|
|
|
|
2019-10-18 13:48:26 -07:00
|
|
|
// Fill size and stride descriptors in memref.
|
|
|
|
|
for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
|
|
|
|
|
int64_t index = indexedSize.index();
|
|
|
|
|
auto sizeAttr =
|
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
|
|
|
|
|
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
|
2019-11-14 09:05:11 -08:00
|
|
|
desc.setSize(rewriter, loc, index, size);
|
2019-10-18 13:48:26 -07:00
|
|
|
auto strideAttr =
|
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
|
|
|
|
|
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
|
2019-11-14 09:05:11 -08:00
|
|
|
desc.setStride(rewriter, loc, index, stride);
|
2019-10-18 13:48:26 -07:00
|
|
|
}
|
|
|
|
|
|
2019-11-14 09:05:11 -08:00
|
|
|
rewriter.replaceOp(op, {desc});
|
2019-10-18 13:48:26 -07:00
|
|
|
return matchSuccess();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2019-08-12 04:08:26 -07:00
|
|
|
/// Populate the given list with patterns that convert from Vector to LLVM.
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
void mlir::populateVectorToLLVMConversionPatterns(
|
|
|
|
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
2019-12-16 09:52:13 -08:00
|
|
|
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
|
|
|
|
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
[VectorOps] Add lowering of vector.shuffle to LLVM IR
For example, a shuffle
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xf32>
becomes a direct LLVM shuffle
0 = llvm.shufflevector %arg0, %arg1 [0 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
but
%1 = vector.shuffle %a, %b[1 : i32, 0 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
becomes the more elaborate (note the index permutation that drives
argument selection for the extract operations)
%0 = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
%1 = llvm.extractvalue %arg1[0] : !llvm<"[2 x <4 x float>]">
%2 = llvm.insertvalue %1, %0[0] : !llvm<"[3 x <4 x float>]">
%3 = llvm.extractvalue %arg0[0] : !llvm<"[1 x <4 x float>]">
%4 = llvm.insertvalue %3, %2[1] : !llvm<"[3 x <4 x float>]">
%5 = llvm.extractvalue %arg1[1] : !llvm<"[2 x <4 x float>]">
%6 = llvm.insertvalue %5, %4[2] : !llvm<"[3 x <4 x float>]">
PiperOrigin-RevId: 285268164
2019-12-12 14:11:27 -08:00
|
|
|
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
converter.getDialect()->getContext(), converter);
|
2019-08-12 04:08:26 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace {
/// Module pass converting vector dialect operations (and the standard ops
/// produced along the way) into the LLVM IR dialect. Body is defined
/// out-of-line below.
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
  void runOnModule() override;
};
} // namespace
|
|
|
|
|
|
|
|
|
|
void LowerVectorToLLVMPass::runOnModule() {
|
|
|
|
|
// Convert to the LLVM IR dialect using the converter defined above.
|
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
|
LLVMTypeConverter converter(&getContext());
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
populateVectorToLLVMConversionPatterns(converter, patterns);
|
2019-08-12 04:08:26 -07:00
|
|
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
|
|
|
|
|
|
|
|
|
ConversionTarget target(getContext());
|
|
|
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
|
|
|
|
target.addDynamicallyLegalOp<FuncOp>(
|
|
|
|
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
|
|
|
|
if (failed(
|
|
|
|
|
applyPartialConversion(getModule(), target, patterns, &converter))) {
|
|
|
|
|
signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-09-13 13:33:46 -07:00
|
|
|
/// Factory for the vector-to-LLVM lowering pass.
/// Returns a heap-allocated pass; ownership transfers to the caller
/// (typically a PassManager), per the pass-creation convention of this API.
OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
  return new LowerVectorToLLVMPass();
}
|
|
|
|
|
|
|
|
|
|
// Register the pass so it is available as `-convert-vector-to-llvm`.
static PassRegistration<LowerVectorToLLVMPass>
    pass("convert-vector-to-llvm",
         "Lower the operations from the vector dialect into the LLVM dialect");