Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes:
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a follow-up commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
|
2018-12-14 09:31:17 -08:00
|
|
|
//
|
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
// =============================================================================
|
|
|
|
|
//
|
|
|
|
|
// This file implements convenience types for working with super-vectorization
|
|
|
|
|
// operations, in particular super-vector loads and stores.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-08-19 17:11:12 -07:00
|
|
|
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
2018-12-14 09:31:17 -08:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2019-08-09 05:58:19 -07:00
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
2018-12-14 09:31:17 -08:00
|
|
|
#include "mlir/Support/LLVM.h"
|
2019-08-09 05:58:19 -07:00
|
|
|
|
2018-12-14 09:31:17 -08:00
|
|
|
using namespace mlir;
|
2019-08-09 05:58:19 -07:00
|
|
|
using namespace mlir::vector;
|
2018-12-14 09:31:17 -08:00
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes:
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a follow-up commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
// VectorOpsDialect
|
2018-12-14 09:31:17 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2019-08-09 05:58:19 -07:00
|
|
|
/// Constructs the vector dialect and registers all of its operations with
/// `context`.
mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // Operations registered explicitly by class name (not through the
  // GET_OP_LIST section of the generated include below).
  addOperations<VectorTransferReadOp>();
  addOperations<VectorTransferWriteOp>();
  addOperations<VectorTypeCastOp>();

  // Operations enumerated by the GET_OP_LIST section of VectorOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
      >();
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// ExtractElementOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Prints an ExtractElementOp as:
///   `op-name %vector[positions] {other-attrs} : vector-type`
/// The `position` attribute is printed inline right after the operand, so it
/// is elided from the trailing attribute dictionary.
static void print(OpAsmPrinter *p, ExtractElementOp op) {
  OpAsmPrinter &printer = *p;
  Value *source = op.vector();
  printer << op.getOperationName() << " " << *source << op.position();
  printer.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"position"});
  printer << " : " << source->getType();
}
|
|
|
|
|
|
|
|
|
|
/// Parses an ExtractElementOp of the form:
///   `%vector [i, j, ...] {other-attrs} : vector-type`
/// The result type is the element type when the position has as many entries
/// as the vector has dimensions. Otherwise it is the vector type with the
/// leading, indexed dimensions dropped.
static ParseResult parseExtractElementOp(OpAsmParser *parser,
                                         OperationState *result) {
  llvm::SMLoc positionLoc, typeLoc;
  SmallVector<NamedAttribute, 4> parsedAttrs;
  OpAsmParser::OperandType sourceInfo;
  Type sourceType;
  Attribute rawPosition;

  bool parseFailed =
      parser->parseOperand(sourceInfo) ||
      parser->getCurrentLocation(&positionLoc) ||
      parser->parseAttribute(rawPosition, "position", parsedAttrs) ||
      parser->parseOptionalAttributeDict(parsedAttrs) ||
      parser->getCurrentLocation(&typeLoc) ||
      parser->parseColonType(sourceType);
  if (parseFailed)
    return failure();

  auto sourceVectorType = sourceType.dyn_cast<VectorType>();
  if (!sourceVectorType)
    return parser->emitError(typeLoc, "expected vector type");

  // The position must be an array with at most `rank` entries.
  auto position = rawPosition.dyn_cast<ArrayAttr>();
  bool tooManyIndices =
      position &&
      static_cast<int64_t>(position.size()) > sourceVectorType.getRank();
  if (!position || tooManyIndices)
    return parser->emitError(
        positionLoc, "expected position attribute of rank smaller than vector");

  // Full indexing yields a scalar; partial indexing yields a sub-vector.
  Type extractedType = sourceVectorType.getElementType();
  if (static_cast<int64_t>(position.size()) != sourceVectorType.getRank())
    extractedType = VectorType::get(
        sourceVectorType.getShape().drop_front(position.size()),
        sourceVectorType.getElementType());

  result->attributes = parsedAttrs;
  if (parser->resolveOperand(sourceInfo, sourceType, result->operands))
    return failure();
  return parser->addTypeToList(extractedType, result->types);
}
|
|
|
|
|
|
|
|
|
|
/// Verifies an ExtractElementOp:
///  - the `position` attribute is non-empty;
///  - it has no more entries than the source vector has dimensions;
///  - every entry is an IntegerAttr that is a valid index into the
///    corresponding vector dimension, i.e. lies in [0, dimSize).
static LogicalResult verify(ExtractElementOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  auto vectorType = op.getVectorType();
  if (positionAttr.size() > static_cast<unsigned>(vectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    // A dimension of size N admits indices 0..N-1. N itself is out of bounds,
    // hence `>=` (the previous `>` accepted an index equal to the size).
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= vectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a positive integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}
|
2019-08-09 06:55:10 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// OuterProductOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Prints an OuterProductOp as:
///   `op-name %lhs, %rhs[, %acc] : lhs-type, rhs-type`
/// The accumulator operand is optional and printed only when present.
static void print(OpAsmPrinter *p, OuterProductOp op) {
  OpAsmPrinter &printer = *p;
  printer << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
  auto accumulators = op.acc();
  if (llvm::size(accumulators) > 0)
    printer << ", " << **accumulators.begin();
  printer << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}
|
|
|
|
|
|
|
|
|
|
static ParseResult parseOuterProductOp(OpAsmParser *parser,
|
|
|
|
|
OperationState *result) {
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
|
|
|
|
|
Type tLHS, tRHS;
|
|
|
|
|
if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) ||
|
|
|
|
|
parser->parseComma() || parser->parseType(tRHS))
|
2019-08-09 06:55:10 -07:00
|
|
|
return failure();
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (operandsInfo.size() < 2)
|
|
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
|
"expected at least 2 operands");
|
|
|
|
|
VectorType vLHS = tLHS.dyn_cast<VectorType>();
|
|
|
|
|
VectorType vRHS = tRHS.dyn_cast<VectorType>();
|
|
|
|
|
if (!vLHS || !vRHS)
|
2019-08-09 06:55:10 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
|
|
|
|
|
vLHS.getElementType());
|
|
|
|
|
return failure(
|
|
|
|
|
parser->resolveOperand(operandsInfo[0], tLHS, result->operands) ||
|
|
|
|
|
parser->resolveOperand(operandsInfo[1], tRHS, result->operands) ||
|
|
|
|
|
(operandsInfo.size() > 2 &&
|
|
|
|
|
parser->resolveOperand(operandsInfo[2], resType, result->operands)) ||
|
|
|
|
|
parser->addTypeToList(resType, result->types));
|
2019-08-09 06:55:10 -07:00
|
|
|
}
|
2018-12-14 09:31:17 -08:00
|
|
|
|
2019-08-09 06:55:10 -07:00
|
|
|
static LogicalResult verify(OuterProductOp op) {
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
VectorType vLHS = op.getOperandVectorTypeLHS(),
|
|
|
|
|
vRHS = op.getOperandVectorTypeRHS(),
|
|
|
|
|
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
|
|
|
|
|
if (vLHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #1");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vRHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #2");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vRES.getRank() != 2)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 2-d vector result");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vLHS.getDimSize(0) != vRES.getDimSize(0))
|
|
|
|
|
return op.emitOpError("expected #1 operand dim to match result dim #1");
|
|
|
|
|
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
|
|
|
|
|
return op.emitOpError("expected #2 operand dim to match result dim #2");
|
|
|
|
|
if (vACC && vACC != vRES)
|
|
|
|
|
return op.emitOpError("expected operand #3 of same type as result type");
|
2019-08-09 06:55:10 -07:00
|
|
|
return success();
|
|
|
|
|
}
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to an FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper FMA instructions for Haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
|
2018-12-14 09:31:17 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTransferReadOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
template <typename EmitFun>
|
2019-04-02 13:09:34 -07:00
|
|
|
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
|
|
|
|
|
EmitFun emitOpError) {
|
2018-12-14 09:31:17 -08:00
|
|
|
SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
|
|
|
|
|
for (auto expr : permutationMap.getResults()) {
|
|
|
|
|
auto dim = expr.dyn_cast<AffineDimExpr>();
|
|
|
|
|
auto zero = expr.dyn_cast<AffineConstantExpr>();
|
|
|
|
|
if (zero) {
|
|
|
|
|
if (zero.getValue() != 0) {
|
|
|
|
|
return emitOpError(
|
|
|
|
|
"requires a projected permutation_map (at most one dim or the zero "
|
|
|
|
|
"constant can appear in each result)");
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (!dim) {
|
|
|
|
|
return emitOpError("requires a projected permutation_map (at most one "
|
|
|
|
|
"dim or the zero constant can appear in each result)");
|
|
|
|
|
}
|
|
|
|
|
if (seen[dim.getPosition()]) {
|
|
|
|
|
return emitOpError(
|
|
|
|
|
"requires a permutation_map that is a permutation (found one dim "
|
|
|
|
|
"used more than once)");
|
|
|
|
|
}
|
|
|
|
|
seen[dim.getPosition()] = true;
|
|
|
|
|
}
|
2019-04-02 13:09:34 -07:00
|
|
|
return success();
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
VectorType vectorType, Value *srcMemRef,
|
|
|
|
|
ArrayRef<Value *> srcIndices,
|
2018-12-14 09:31:17 -08:00
|
|
|
AffineMap permutationMap,
|
2018-12-27 14:35:10 -08:00
|
|
|
Optional<Value *> paddingValue) {
|
2018-12-14 09:31:17 -08:00
|
|
|
result->addOperands(srcMemRef);
|
|
|
|
|
result->addOperands(srcIndices);
|
|
|
|
|
if (paddingValue) {
|
|
|
|
|
result->addOperands({*paddingValue});
|
|
|
|
|
}
|
|
|
|
|
result->addAttribute(getPermutationMapAttrName(),
|
|
|
|
|
builder->getAffineMapAttr(permutationMap));
|
|
|
|
|
result->addTypes(vectorType);
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-04 16:24:10 -07:00
|
|
|
/// Returns the index operands: one per memref dimension, starting at
/// Offsets::FirstIndexOffset.
auto VectorTransferReadOp::getIndices() -> operand_range {
  auto first = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  return {first, first + getMemRefType().getRank()};
}
|
|
|
|
|
|
2018-12-27 14:35:10 -08:00
|
|
|
/// Returns the optional padding operand. It sits immediately after the last
/// index operand. Returns None when the op has no operand at that position.
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
  auto memRefRank = getMemRefType().getRank();
  auto paddingPos = Offsets::FirstIndexOffset + memRefRank;
  if (getNumOperands() > paddingPos)
    return Optional<Value *>(getOperand(paddingPos));
  return None;
}
|
|
|
|
|
|
2019-03-23 15:09:06 -07:00
|
|
|
/// Returns the affine map stored in the permutation-map attribute.
AffineMap VectorTransferReadOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
|
|
|
|
|
|
2019-03-23 09:03:07 -07:00
|
|
|
void VectorTransferReadOp::print(OpAsmPrinter *p) {
|
2018-12-14 09:31:17 -08:00
|
|
|
*p << getOperationName() << " ";
|
|
|
|
|
p->printOperand(getMemRef());
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
*p << "[";
|
2018-12-14 09:31:17 -08:00
|
|
|
p->printOperands(getIndices());
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
*p << "]";
|
2018-12-14 09:31:17 -08:00
|
|
|
auto optionalPaddingValue = getPaddingValue();
|
|
|
|
|
if (optionalPaddingValue) {
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
*p << ", (";
|
2018-12-14 09:31:17 -08:00
|
|
|
p->printOperand(*optionalPaddingValue);
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
*p << ")";
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
p->printOptionalAttrDict(getAttrs());
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
*p << " : " << getMemRefType();
|
|
|
|
|
*p << ", " << getResultType();
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
|
2019-05-06 22:01:31 -07:00
|
|
|
ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
|
|
|
|
|
OperationState *result) {
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
OpAsmParser::OperandType memrefInfo;
|
|
|
|
|
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
|
|
|
|
|
SmallVector<OpAsmParser::OperandType, 8> paddingInfo;
|
|
|
|
|
SmallVector<Type, 2> types;
|
2018-12-14 09:31:17 -08:00
|
|
|
|
|
|
|
|
// Parsing with support for optional paddingValue.
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
if (parser->parseOperand(memrefInfo) ||
|
2019-06-04 23:33:18 -07:00
|
|
|
parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
|
|
|
|
parser->parseTrailingOperandList(paddingInfo,
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
OpAsmParser::Delimiter::Paren) ||
|
|
|
|
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
|
|
|
|
parser->parseColonTypeList(types))
|
2019-05-06 22:01:31 -07:00
|
|
|
return failure();
|
2018-12-14 09:31:17 -08:00
|
|
|
|
|
|
|
|
// Resolution.
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
if (types.size() != 2)
|
|
|
|
|
return parser->emitError(parser->getNameLoc(), "expected 2 types");
|
|
|
|
|
MemRefType memrefType = types[0].dyn_cast<MemRefType>();
|
2018-12-14 09:31:17 -08:00
|
|
|
if (!memrefType)
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(), "memRef type expected");
|
|
|
|
|
VectorType vectorType = types[1].dyn_cast<VectorType>();
|
2018-12-14 09:31:17 -08:00
|
|
|
if (!vectorType)
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(), "vector type expected");
|
2018-12-14 09:31:17 -08:00
|
|
|
|
|
|
|
|
// Extract optional paddingValue.
|
2019-08-09 05:58:19 -07:00
|
|
|
// At this point, indexInfo may contain the optional paddingValue, pop it
|
|
|
|
|
// out.
|
2019-05-31 16:41:21 -07:00
|
|
|
if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
|
"expected " + Twine(memrefType.getRank()) +
|
|
|
|
|
" indices to the memref");
|
|
|
|
|
if (paddingInfo.size() > 1)
|
|
|
|
|
return parser->emitError(parser->getNameLoc(),
|
|
|
|
|
"expected at most one padding value");
|
2018-12-14 09:31:17 -08:00
|
|
|
Type paddingType;
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
bool hasOptionalPaddingValue = !paddingInfo.empty();
|
|
|
|
|
if (hasOptionalPaddingValue) {
|
|
|
|
|
paddingType = vectorType.getElementType();
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
auto indexType = parser->getBuilder().getIndexType();
|
2019-05-06 22:01:31 -07:00
|
|
|
return failure(
|
|
|
|
|
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
|
|
|
|
|
parser->resolveOperands(indexInfo, indexType, result->operands) ||
|
|
|
|
|
(hasOptionalPaddingValue &&
|
|
|
|
|
parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) ||
|
|
|
|
|
parser->addTypeToList(vectorType, result->types));
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
|
2019-04-02 13:09:34 -07:00
|
|
|
/// Verifies a transfer read.
///
/// Operand layout: the source memref, one `index` operand per memref
/// dimension, then an optional padding value. The result must be a vector
/// whose element type matches the memref's, and the `permutation_map`
/// attribute must map the memref's dimensions onto the vector's dimensions.
LogicalResult VectorTransferReadOp::verify() {
  // Consistency of memref type in function type.
  if (llvm::empty(getOperands())) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref as first operand");
  }
  // Consistency of vector type in function type.
  if (!getResult()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector result type in function type: "
                       "memref_type<...xelemental_type>, vector_type");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getResultType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types: memref + 'rank' indices + optional
  // padding value.
  auto optionalPaddingValue = getPaddingValue();
  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
                                 memrefType.getRank() +
                                 (optionalPaddingValue ? 1 : 0);
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError("expects ")
           << expectedNumOperands << " operands (of which "
           << memrefType.getRank() << " indices)";
  }
  // Consistency of padding value with vector type.
  if (optionalPaddingValue) {
    auto paddingValue = *optionalPaddingValue;
    auto elementalType = paddingValue->getType();
    if (!VectorType::isValidElementType(elementalType)) {
      return emitOpError("requires valid padding vector elemental type");
    }
    if (elementalType != vectorType.getElementType()) {
      return emitOpError(
          "requires formal padding and vector of the same elemental type");
    }
  }
  // Consistency of indices types.
  unsigned numIndices = 0;
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_read must have 'index' type");
    }
    ++numIndices;
  }
  if (numIndices != memrefType.getRank()) {
    return emitOpError("requires at least a memref operand followed by ")
           << memrefType.getRank() << " indices";
  }

  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // Close the parenthesis opened in the message so the diagnostic reads
    // "(<map results> vs <vector rank>)".
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTransferWriteOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *srcVector, Value *dstMemRef,
|
|
|
|
|
ArrayRef<Value *> dstIndices,
|
2018-12-14 09:31:17 -08:00
|
|
|
AffineMap permutationMap) {
|
|
|
|
|
result->addOperands({srcVector, dstMemRef});
|
|
|
|
|
result->addOperands(dstIndices);
|
|
|
|
|
result->addAttribute(getPermutationMapAttrName(),
|
|
|
|
|
builder->getAffineMapAttr(permutationMap));
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-04 16:24:10 -07:00
|
|
|
/// Returns the index operands: `rank(memref)` operands starting at
/// Offsets::FirstIndexOffset.
auto VectorTransferWriteOp::getIndices() -> operand_range {
  auto first = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
  auto rank = getMemRefType().getRank();
  return {first, first + rank};
}
|
|
|
|
|
|
2019-03-23 15:09:06 -07:00
|
|
|
/// Returns the AffineMap stored in the permutation_map attribute.
AffineMap VectorTransferWriteOp::getPermutationMap() {
  auto mapAttr = getAttrOfType<AffineMapAttr>(getPermutationMapAttrName());
  return mapAttr.getValue();
}
|
|
|
|
|
|
2019-03-23 09:03:07 -07:00
|
|
|
/// Prints the op as
///   <op-name> %vector, %memref[%i0, ...] {attrs}? : vector-type, memref-type
void VectorTransferWriteOp::print(OpAsmPrinter *p) {
  *p << getOperationName() << " " << *getVector() << ", " << *getMemRef()
     << "[";
  p->printOperands(getIndices());
  *p << "]";
  p->printOptionalAttrDict(getAttrs());
  *p << " : " << getVectorType() << ", " << getMemRefType();
}
|
|
|
|
|
|
2019-05-06 22:01:31 -07:00
|
|
|
ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
|
|
|
|
|
OperationState *result) {
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
OpAsmParser::OperandType storeValueInfo;
|
|
|
|
|
OpAsmParser::OperandType memrefInfo;
|
|
|
|
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
|
|
|
|
SmallVector<Type, 2> types;
|
|
|
|
|
auto indexType = parser->getBuilder().getIndexType();
|
|
|
|
|
if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
|
|
|
|
|
parser->parseOperand(memrefInfo) ||
|
2019-06-04 23:33:18 -07:00
|
|
|
parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
|
|
|
|
parser->parseColonTypeList(types))
|
2019-05-06 22:01:31 -07:00
|
|
|
return failure();
|
2018-12-14 09:31:17 -08:00
|
|
|
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
if (types.size() != 2)
|
|
|
|
|
return parser->emitError(parser->getNameLoc(), "expected 2 types");
|
2018-12-14 09:31:17 -08:00
|
|
|
VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
|
|
|
|
|
if (!vectorType)
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(), "vector type expected");
|
2018-12-14 09:31:17 -08:00
|
|
|
MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
|
|
|
|
|
if (!memrefType)
|
Cleanup SuperVectorization dialect printing and parsing.
On the read side,
```
%3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: (d0, d1, d2)->(d2, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
```
becomes:
```
%3 = vector_transfer_read %arg0[%i2, %i1, %i0] {permutation_map: (d0, d1, d2)->(d2, d0)} : memref<?x?x?xf32>, vector<32x256xf32>
```
On the write side,
```
vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
```
becomes
```
vector_transfer_write %0, %arg0[%c3, %c3] {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
```
Documentation will be cleaned up in a followup commit that also extracts a proper .md from the top of the file comments.
PiperOrigin-RevId: 241021879
2019-03-29 11:48:20 -07:00
|
|
|
return parser->emitError(parser->getNameLoc(), "memRef type expected");
|
2018-12-14 09:31:17 -08:00
|
|
|
|
2019-05-06 22:01:31 -07:00
|
|
|
return failure(
|
|
|
|
|
parser->resolveOperands(storeValueInfo, vectorType, result->operands) ||
|
|
|
|
|
parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
|
|
|
|
|
parser->resolveOperands(indexInfo, indexType, result->operands));
|
2018-12-14 09:31:17 -08:00
|
|
|
}
|
|
|
|
|
|
2019-04-02 13:09:34 -07:00
|
|
|
/// Verifies a transfer write.
///
/// Operand layout: the vector to store, the destination memref, then one
/// `index` operand per memref dimension. The vector and memref must share an
/// element type, and the `permutation_map` attribute must map the memref's
/// dimensions onto the vector's dimensions.
LogicalResult VectorTransferWriteOp::verify() {
  // Both leading operands (vector and memref) must exist before getVector()
  // and getMemRef() are queried. Rejecting only an empty operand list would
  // still let an op holding just the vector operand reach getMemRef().
  if (getNumOperands() < Offsets::FirstIndexOffset) {
    return emitOpError(
        "requires at least a memref operand followed by 'rank' indices");
  }
  if (!getMemRef()->getType().isa<MemRefType>()) {
    return emitOpError("requires a memref first operand");
  }
  // Consistency of vector type in function type.
  if (!getVector()->getType().isa<VectorType>()) {
    return emitOpError("should have a vector input type in function type: "
                       "(vector_type, memref_type [, elemental_type]) -> ()");
  }
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = getMemRefType();
  VectorType vectorType = getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return emitOpError(
        "requires memref and vector types of the same elemental type");
  // Consistency of number of input types: vector + memref + 'rank' indices.
  unsigned expectedNumOperands =
      Offsets::FirstIndexOffset + memrefType.getRank();
  // Checks on the actual operands and their types.
  if (getNumOperands() != expectedNumOperands) {
    return emitOpError() << "expects " << expectedNumOperands
                         << " operands (of which " << memrefType.getRank()
                         << " indices)";
  }
  // Consistency of indices types.
  unsigned numIndices = 0;
  for (auto *idx : getIndices()) {
    if (!idx->getType().isIndex()) {
      return emitOpError(
          "index to vector.transfer_write must have 'index' type");
    }
    ++numIndices;
  }
  if (numIndices != memrefType.getRank()) {
    return emitOpError("requires at least a memref operand followed by ")
           << memrefType.getRank() << " indices";
  }

  // Consistency of AffineMap attribute.
  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
  }
  auto permutationMap = getPermutationMap();
  if (permutationMap.getNumSymbols() != 0) {
    return emitOpError("requires a permutation_map without symbols");
  }
  if (permutationMap.getNumInputs() != memrefType.getRank()) {
    return emitOpError("requires a permutation_map with input dims of the "
                       "same rank as the memref type");
  }
  if (permutationMap.getNumResults() != vectorType.getRank()) {
    // Close the parenthesis opened in the message so the diagnostic reads
    // "(<map results> vs <vector rank>)".
    return emitOpError("requires a permutation_map with result dims of the "
                       "same rank as the vector type (")
           << permutationMap.getNumResults() << " vs " << vectorType.getRank()
           << ")";
  }
  return verifyPermutationMap(permutationMap,
                              [this](Twine t) { return emitOpError(t); });
}
|
2018-12-17 14:10:52 -08:00
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// VectorTypeCastOp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
|
2018-12-27 14:35:10 -08:00
|
|
|
Value *srcVector, Type dstType) {
|
2018-12-17 14:10:52 -08:00
|
|
|
result->addOperands(srcVector);
|
|
|
|
|
result->addTypes(dstType);
|
|
|
|
|
}
|
|
|
|
|
|
2019-05-06 22:01:31 -07:00
|
|
|
/// Parses `%src {attrs}? : source-type, target-type`. The target type becomes
/// the result type; the operand is resolved against the source type.
ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
                                    OperationState *result) {
  OpAsmParser::OperandType source;
  Type sourceType;
  Type targetType;
  if (parser->parseOperand(source) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(sourceType) || parser->parseComma() ||
      parser->parseType(targetType))
    return failure();
  if (parser->addTypeToList(targetType, result->types) ||
      parser->resolveOperand(source, sourceType, result->operands))
    return failure();
  return success();
}
|
|
|
|
|
|
2019-03-23 09:03:07 -07:00
|
|
|
/// Prints `<op-name> %src : source-type, target-type`.
void VectorTypeCastOp::print(OpAsmPrinter *p) {
  auto *source = getOperand();
  *p << getOperationName() << ' ' << *source << " : " << source->getType()
     << ", " << getType();
}
|
|
|
|
|
|
2019-04-02 13:09:34 -07:00
|
|
|
/// Verifies the cast. The result must be a statically shaped memref whose
/// element type is a vector, and the operand must itself be a memref.
LogicalResult VectorTypeCastOp::verify() {
  auto dstMemrefType = getType().dyn_cast<MemRefType>();
  if (!dstMemrefType)
    return emitOpError("expects target type to be a memref type");
  if (!dstMemrefType.getElementType().isa<VectorType>())
    return emitOpError(
        "expects vector as an element of the target memref type");
  if (!dstMemrefType.hasStaticShape())
    return emitOpError("does not support dynamic shapes");

  // The source side of the cast must be a memref as well.
  bool sourceIsMemRef = getOperand()->getType().isa<MemRefType>();
  if (!sourceIsMemRef)
    return emitOpError("expects source type to be a memref type");
  return success();
}
|
2019-08-09 05:58:19 -07:00
|
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
2019-08-19 17:11:12 -07:00
|
|
|
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
|
2019-08-09 05:58:19 -07:00
|
|
|
|
|
|
|
|
} // namespace mlir
|