Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
|
2018-12-14 09:31:17 -08:00
|
|
|
//
|
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
// =============================================================================
|
|
|
|
|
//
|
|
|
|
|
// This file implements convenience types for working with super-vectorization
|
|
|
|
|
// operations, in particular super-vector loads and stores.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-08-19 17:11:12 -07:00
|
|
|
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
2018-12-14 09:31:17 -08:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2019-08-09 05:58:19 -07:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2018-12-14 09:31:17 -08:00
|
|
|
#include "mlir/Support/LLVM.h"
|
2019-08-09 05:58:19 -07:00
|
|
|
|
2018-12-14 09:31:17 -08:00
|
|
|
using namespace mlir;
|
2019-08-09 05:58:19 -07:00
|
|
|
using namespace mlir::vector;
|
2018-12-14 09:31:17 -08:00
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
// VectorOpsDialect
|
2018-12-14 09:31:17 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-08-09 05:58:19 -07:00
|
|
|
// Constructs the VectorOps dialect and registers all of its operations.
// The operation list is generated from the ODS definitions; including the
// .cpp.inc under GET_OP_LIST expands to the comma-separated op class list
// expected by addOperations<>.
mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
      >();
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-18 10:38:35 -08:00
|
|
|
// VectorExtractElementOp
|
2019-08-09 05:58:19 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
/// Computes the result type of an extract-element operation: extracting with
/// a position for every dimension yields a scalar of the element type;
/// extracting with fewer positions yields a vector with the leading
/// `position.size()` dimensions dropped.
static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  int64_t numIndices = static_cast<int64_t>(position.size());
  if (numIndices == vectorType.getRank())
    return vectorType.getElementType();
  // Partial extraction: keep the trailing dims of the source vector shape.
  return VectorType::get(vectorType.getShape().drop_front(numIndices),
                         vectorType.getElementType());
}
|
|
|
|
|
|
|
|
|
|
/// Builds a VectorExtractElementOp from a source vector and a static
/// position, inferring the result type from the source type and the number
/// of position indices.
void VectorExtractElementOp::build(Builder *builder, OperationState &result,
                                   Value *source, ArrayRef<int32_t> position) {
  auto posAttr = builder->getI32ArrayAttr(position);
  auto sourceVectorType = source->getType().cast<VectorType>();
  result.addOperands(source);
  result.addAttribute(getPositionAttrName(), posAttr);
  result.addTypes(inferExtractOpResultType(sourceVectorType, posAttr));
}
|
|
|
|
|
|
|
|
|
|
/// Prints the op in the custom form:
///   <op-name> <vector-operand><position-array> {attrs} : <vector-type>
/// The position attribute is printed inline, so it is elided from the
/// trailing attribute dictionary.
static void print(OpAsmPrinter &p, VectorExtractElementOp op) {
  p << op.getOperationName() << " ";
  p << *op.vector() << op.position();
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"position"});
  p << " : " << op.vector()->getType();
}
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
// Parses a VectorExtractElementOp in the custom form printed above:
//   <operand><position-array-attr> {attrs} : <vector-type>
// The result type is not written in the assembly; it is re-inferred from the
// parsed vector type and position attribute.
static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
                                               OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  // Locations are captured just before the attribute and the type so that
  // diagnostics below point at the offending token.
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  // The position must be an array attribute with at most rank entries
  // (rank-many entries extract a scalar, fewer extract a sub-vector).
  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
/// Verifies a VectorExtractElementOp: the position attribute must be a
/// non-empty array of at most rank-many integers, each within the bounds of
/// the corresponding source-vector dimension.
static LogicalResult verify(VectorExtractElementOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    // Valid positions are in [0, dimSize). The previous check used `>`,
    // which wrongly accepted a position equal to the dimension size.
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a positive integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}
|
2019-11-18 10:38:35 -08:00
|
|
|
|
2019-08-09 06:55:10 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-18 10:38:35 -08:00
|
|
|
// VectorOuterProductOp
|
2019-08-09 06:55:10 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
// Prints a VectorOuterProductOp in the custom form:
//   <op-name> <lhs>, <rhs>[, <acc>] {attrs?} : <lhs-type>, <rhs-type>
static void print(OpAsmPrinter &p, VectorOuterProductOp op) {
  p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
  // The accumulator operand is optional; print it only when present.
  if (llvm::size(op.acc()) > 0)
    p << ", " << **op.acc().begin();
  // The result type is not printed: it is inferred from the operand types.
  p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
|
|
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
// Parses a VectorOuterProductOp:
//   <op-name> <lhs>, <rhs>[, <acc>] : <lhs-type>, <rhs-type>
// The 2-d result type is inferred from the two 1-d operand vector types; the
// optional third operand (the accumulator) is resolved against that inferred
// result type.
static ParseResult parseVectorOuterProductOp(OpAsmParser &parser,
                                             OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
      parser.parseComma() || parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS || !vRHS)
    return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
  // Result shape is (lhs-size x rhs-size); the element type is taken from
  // the LHS (no cross-check against the RHS element type is done here).
  VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                                       vLHS.getElementType());
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      // The optional accumulator must have the inferred result type.
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
|
2018-12-14 09:31:17 -08:00
|
|
|
|
2019-11-18 10:38:35 -08:00
|
|
|
static LogicalResult verify(VectorOuterProductOp op) {
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
VectorType vLHS = op.getOperandVectorTypeLHS(),
|
|
|
|
|
vRHS = op.getOperandVectorTypeRHS(),
|
|
|
|
|
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
|
|
|
|
|
if (vLHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #1");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vRHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #2");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vRES.getRank() != 2)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 2-d vector result");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vLHS.getDimSize(0) != vRES.getDimSize(0))
|
|
|
|
|
return op.emitOpError("expected #1 operand dim to match result dim #1");
|
|
|
|
|
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
|
|
|
|
|
return op.emitOpError("expected #2 operand dim to match result dim #2");
|
|
|
|
|
if (vACC && vACC != vRES)
|
|
|
|
|
return op.emitOpError("expected operand #3 of same type as result type");
|
2019-08-09 06:55:10 -07:00
|
|
|
return success();
|
|
|
|
|
}
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
|
2018-12-14 09:31:17 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTransferReadOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
template <typename EmitFun>
|
2019-04-02 13:09:34 -07:00
|
|
|
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
|
|
|
|
|
EmitFun emitOpError) {
|
2018-12-14 09:31:17 -08:00
|
|
|
SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
|
|
|
|
|
for (auto expr : permutationMap.getResults()) {
|
|
|
|
|
auto dim = expr.dyn_cast<AffineDimExpr>();
|
|
|
|
|
auto zero = expr.dyn_cast<AffineConstantExpr>();
|
|
|
|
|
if (zero) {
|
|
|
|
|
if (zero.getValue() != 0) {
|
|
|
|
|
return emitOpError(
|
|
|
|
|
"requires a projected permutation_map (at most one dim or the zero "
|
|
|
|
|
"constant can appear in each result)");
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (!dim) {
|
|
|
|
|
return emitOpError("requires a projected permutation_map (at most one "
|
|
|
|
|
"dim or the zero constant can appear in each result)");
|
|
|
|
|
}
|
|
|
|
|
if (seen[dim.getPosition()]) {
|
|
|
|
|
return emitOpError(
|
|
|
|
|
"requires a permutation_map that is a permutation (found one dim "
|
|
|
|
|
"used more than once)");
|
|
|
|
|
}
|
|
|
|
|
seen[dim.getPosition()] = true;
|
|
|
|
|
}
|
2019-04-02 13:09:34 -07:00
|
|
|
return success();
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
/// Prints a VectorTransferReadOp in the custom form:
///   <op-name> <memref>[<indices>], <padding> {attrs} : <memref-type>, <vector-type>
static void print(OpAsmPrinter &p, VectorTransferReadOp op) {
  p << op.getOperationName() << " " << *op.memref() << "[";
  p.printOperands(op.indices());
  p << "], " << *op.padding() << " ";
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getMemRefType() << ", " << op.getVectorType();
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Parses a vector_transfer_read op of the form:
//   %v = vector_transfer_read %memref[%i0, ..., %ik], %padding {attrs}
//          : memref-type, vector-type
ParseResult parseVectorTransferReadOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Operand placeholders; they are resolved once the trailing types are known.
  OpAsmParser::OperandType memrefOperand;
  SmallVector<OpAsmParser::OperandType, 8> indexOperands;
  OpAsmParser::OperandType paddingOperand;
  SmallVector<Type, 2> parsedTypes;
  llvm::SMLoc typesLoc;

  // Syntax: memref, square-bracketed indices, a comma, the padding value, an
  // optional attribute dictionary, then a colon-separated type list.
  if (parser.parseOperand(memrefOperand) ||
      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingOperand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) ||
      parser.parseColonTypeList(parsedTypes))
    return failure();

  // Exactly two types are expected: the source memref type and the result
  // vector type, in that order.
  if (parsedTypes.size() != 2)
    return parser.emitError(typesLoc, "two types required");
  MemRefType memRefType = parsedTypes[0].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "memref type required");
  Type vectorType = parsedTypes[1];

  // Resolve operands: indices are always of index type, and the padding value
  // carries the memref's elemental type.
  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperand(memrefOperand, memRefType, result.operands) ||
      parser.resolveOperands(indexOperands, indexType, result.operands) ||
      parser.resolveOperand(paddingOperand, memRefType.getElementType(),
                            result.operands))
    return failure();
  return parser.addTypeToList(vectorType, result.types);
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Verifies a VectorTransferReadOp: elemental-type agreement between memref,
// vector and padding, index count, and well-formedness of the permutation map.
static LogicalResult verify(VectorTransferReadOp op) {
  MemRefType memRefTy = op.getMemRefType();
  VectorType vectorTy = op.getVectorType();

  // The memref and the result vector must agree on the elemental type.
  if (memRefTy.getElementType() != vectorTy.getElementType())
    return op.emitOpError(
        "requires memref and vector types of the same elemental type");

  // The padding value must be a valid vector element and match the vector's
  // elemental type.
  auto paddingTy = op.padding()->getType();
  if (!VectorType::isValidElementType(paddingTy))
    return op.emitOpError("requires valid padding vector elemental type");
  if (paddingTy != vectorTy.getElementType())
    return op.emitOpError(
        "requires formal padding and vector of the same elemental type");

  // One index operand per memref dimension.
  if (llvm::size(op.indices()) != memRefTy.getRank())
    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";

  // The permutation map must be symbol-free and map memref dims to vector
  // dims.
  auto permMap = op.permutation_map();
  if (permMap.getNumSymbols() != 0)
    return op.emitOpError("requires permutation_map without symbols");
  if (permMap.getNumInputs() != memRefTy.getRank())
    return op.emitOpError("requires a permutation_map with input dims of the "
                          "same rank as the memref type");
  if (permMap.getNumResults() != vectorTy.getRank())
    return op.emitOpError("requires a permutation_map with result dims of the "
                          "same rank as the vector type");
  // Delegate the remaining structural checks on the map itself.
  return verifyPermutationMap(permMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTransferWriteOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-14 08:10:36 -08:00
|
|
|
// Prints a VectorTransferWriteOp in the form:
//   vector_transfer_write %vec, %memref[%i0, ..., %ik] {attrs}
//     : vector-type, memref-type
static void print(OpAsmPrinter &p, VectorTransferWriteOp op) {
  // Leading operands: the stored vector, then the destination memref.
  p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
    << "[";
  // Indices appear in square brackets, immediately after the memref.
  p.printOperands(op.indices());
  p << "]";
  p.printOptionalAttrDict(op.getAttrs());
  // Trailing type list: vector type first, memref type second.
  p << " : ";
  p.printType(op.getVectorType());
  p << ", ";
  p.printType(op.getMemRefType());
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Parses a vector_transfer_write op of the form:
//   vector_transfer_write %vec, %memref[%i0, ..., %ik] {attrs}
//     : vector-type, memref-type
ParseResult parseVectorTransferWriteOp(OpAsmParser &parser,
                                       OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  SmallVector<Type, 2> types;
  // Syntax: stored vector, comma, memref, square-bracketed indices, optional
  // attribute dictionary, then a colon-separated type list.
  if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memRefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the vector type and the memref type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "two types required");
  auto indexType = parser.getBuilder().getIndexType();
  Type vectorType = types[0];
  // Validate the destination type up front, mirroring the read-side parser,
  // so the user gets a precise diagnostic instead of a late verifier error.
  MemRefType memRefType = types[1].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "memref type required");
  return failure(
      parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
      parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands));
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Verifies a VectorTransferWriteOp: elemental-type agreement between memref
// and vector, index count, and well-formedness of the permutation map.
static LogicalResult verify(VectorTransferWriteOp op) {
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = op.getMemRefType();
  VectorType vectorType = op.getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return op.emitOpError(
        "requires memref and vector types of the same elemental type");
  // One index operand per memref dimension.
  if (llvm::size(op.indices()) != memrefType.getRank())
    return op.emitOpError("requires ") << memrefType.getRank() << " indices";

  // Consistency of AffineMap attribute.
  auto permutationMap = op.permutation_map();
  if (permutationMap.getNumSymbols() != 0)
    return op.emitOpError("requires a symbol-less permutation_map");
  if (permutationMap.getNumInputs() != memrefType.getRank())
    return op.emitOpError("requires a permutation_map with input dims of the "
                          "same rank as the memref type: ")
           << permutationMap.getNumInputs() << " vs " << memrefType;
  if (permutationMap.getNumResults() != vectorType.getRank())
    // Fixed diagnostic punctuation: was "type." which fused awkwardly with the
    // streamed counts; now matches the memref message above ("type: N vs T").
    return op.emitOpError("requires a permutation_map with result dims of the "
                          "same rank as the vector type: ")
           << permutationMap.getNumResults() << " vs " << vectorType;
  // Delegate the remaining structural checks on the map itself.
  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}
|
2018-12-17 14:10:52 -08:00
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTypeCastOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Computes the result type of a vector_type_cast: a 0-d memref whose single
// element is a vector carrying the source memref's shape and elemental type.
static MemRefType inferVectorTypeCastResultType(MemRefType t) {
  VectorType wrappedVector = VectorType::get(t.getShape(), t.getElementType());
  return MemRefType::get({}, wrappedVector);
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Builds a VectorTypeCastOp from its single memref operand; the result type
// is fully inferred from the source memref type.
void VectorTypeCastOp::build(Builder *builder, OperationState &result,
                             Value *source) {
  auto sourceMemRefType = source->getType().cast<MemRefType>();
  result.addOperands(source);
  result.addTypes(inferVectorTypeCastResultType(sourceMemRefType));
}
|
|
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Prints a VectorTypeCastOp in the form:
//   vector_type_cast %memref : source-memref-type to result-memref-type
static void print(OpAsmPrinter &p, VectorTypeCastOp &op) {
  MemRefType sourceType = op.getOperand()->getType().cast<MemRefType>();
  p << op.getOperationName() << ' ' << *op.memref() << " : " << sourceType
    << " to " << inferVectorTypeCastResultType(sourceType);
}
|
2018-12-17 14:10:52 -08:00
|
|
|
|
2019-11-14 08:10:36 -08:00
|
|
|
// Verifies a VectorTypeCastOp: the result type is fully determined by the
// operand type, so re-derive it and require an exact match.
static LogicalResult verify(VectorTypeCastOp &op) {
  MemRefType expectedType = inferVectorTypeCastResultType(op.getMemRefType());
  if (op.getResultMemRefType() != expectedType)
    return op.emitOpError("expects result type to be: ") << expectedType;
  return success();
}
|
2019-08-09 05:58:19 -07:00
|
|
|
|
|
|
|
|
namespace mlir {

// Pull in the TableGen-generated op class definitions (method bodies) for the
// VectorOps dialect; declarations come from the corresponding .h.inc.
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"

} // namespace mlir
|