//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/VectorOps/VectorOps.h"

#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// VectorOpsDialect
//===----------------------------------------------------------------------===//

mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

static ParseResult parseContractionOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType lhsInfo;
  OpAsmParser::OperandType rhsInfo;
  OpAsmParser::OperandType accInfo;
  SmallVector<OpAsmParser::OperandType, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultVectorType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO(andydavis, ntv) Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
      parser.parseOperand(lhsInfo) || parser.parseComma() ||
      parser.parseOperand(rhsInfo) || parser.parseComma() ||
      parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultVectorType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
      parser.addTypeToList(resultVectorType, result.types))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = types[0].cast<VectorType>();
  auto rhsType = types[1].cast<VectorType>();
  auto maskElementType = parser.getBuilder().getI1Type();
  SmallVector<Type, 2> maskTypes;
  maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
  maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}
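
// A sketch of the textual form this parser accepts. The map aliases, iterator
// types and shapes below are illustrative assumptions, not taken from a test:
//
//   %3 = vector.contract {indexing_maps = [#map0, #map1, #map2],
//                         iterator_types = ["parallel", "parallel",
//                                           "reduction"]}
//        %lhs, %rhs, %acc
//      : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
//
// Two optional vector mask operands may trail the accumulator. The leading
// dictionary is parsed under the placeholder key "_" and then re-assigned
// wholesale into the op's attribute list by the code above.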

static void print(OpAsmPrinter &p, ContractionOp op) {
  // TODO(andydavis, ntv) Unify printing code with linalg ops.
  auto attrNames = op.getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : op.getAttrs()) {
    if (traitAttrsSet.count(attr.first.strref()) > 0)
      attrs.push_back(attr);
  }
  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
  p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
  p << *op.rhs() << ", " << *op.acc();
  if (llvm::size(op.masks()) == 2) {
    p << ", " << **op.masks().begin();
    p << ", " << **(op.masks().begin() + 1);
  }
  p.printOptionalAttrDict(op.getAttrs(), attrNames);
  p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
    << op.getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static bool verifyOutputShape(
    VectorType lhsType, VectorType rhsType, VectorType accType,
    VectorType resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify the dimensions of 'resType' and 'accType' against
  // 'expectedResultDims'.
  if (resType.getShape().size() != expectedResultDims.size() ||
      accType.getShape().size() != expectedResultDims.size())
    return false;
  for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
    if (resType.getDimSize(i) != expectedResultDims[i] ||
        accType.getDimSize(i) != expectedResultDims[i])
      return false;
  }
  return true;
}
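
// Worked example (assumed shapes): for a matmul-like contraction with
// lhsType = vector<4x8xf32>, rhsType = vector<8x16xf32> and
// contractingDimMap = {(1, 0)}, the free dimensions are lhs dim 0 and
// rhs dim 1, so 'expectedResultDims' is [4, 16] and both 'accType' and
// 'resType' must be vector<4x16xf32>.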

static LogicalResult verify(ContractionOp op) {
  auto lhsType = op.getLhsType();
  auto rhsType = op.getRhsType();
  auto accType = op.getAccType();
  auto resType = op.getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (op.indexing_maps().size() != 3)
    return op.emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = op.iterator_types().getValue().size();
  for (auto it : llvm::enumerate(op.indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    if (map.getNumSymbols() != 0)
      return op.emitOpError("expected indexing map ")
             << index << " to have no symbols";
    if (map.getNumDims() != numIterators)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    auto operandType = op.getOperand(index)->getType().cast<VectorType>();
    unsigned rank = operandType.getShape().size();
    if (map.getNumResults() != rank)
      return op.emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return op.emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = op.getContractingDimMap();
  auto batchDimMap = op.getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return op.emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return op.emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return op.emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
                         batchDimMap))
    return op.emitOpError("invalid accumulator/result vector shape");

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = op.getLHSVectorMaskType();
  auto rhsMaskType = op.getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return op.emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return op.emitOpError("invalid vector mask rank");
  }
  return success();
}

SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          StringRef targetIteratorTypeName, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (auto it : llvm::enumerate(iteratorTypes)) {
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName != targetIteratorTypeName)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.push_back({lhsDim, rhsDim});
  }
  return dimMap;
}
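
// For the illustrative matmul above (lhs map (d0, d2), rhs map (d2, d1),
// result map (d0, d1), iterator_types ["parallel", "parallel", "reduction"]),
// querying the reduction iterator returns {(1, 0)}: d2 is lhs result 1 and
// rhs result 0. Only iterators that appear in *both* the lhs and rhs maps
// contribute a pair, which is why querying the parallel iterators yields
// exactly the batch dimensions.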

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resShape = getResultType().getShape();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  SmallVector<int64_t, 2> iterationShape;
  for (auto it : llvm::enumerate(iterator_types())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
    if (iteratorTypeName == getReductionIteratorTypeName()) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    iterationBounds.push_back(resShape[resDimIndex]);
  }
}
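
// For the same illustrative matmul (lhs = vector<4x8xf32>,
// rhs = vector<8x16xf32>, result = vector<4x16xf32>), the bounds come out as
// [4, 16, 8]: the two parallel sizes are read off the result shape and the
// reduction size off the lhs shape.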

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = indexing_maps().getValue().size();
  iterationIndexMap.resize(numMaps);
  for (auto it : llvm::enumerate(indexing_maps())) {
    auto index = it.index();
    auto map = it.value().cast<AffineMapAttr>().getValue();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = map.getResult(i).cast<AffineDimExpr>();
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
  return getDimMap(indexingMaps, iterator_types(),
                   getParallelIteratorTypeName(), getContext());
}

SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  SmallVector<AffineMap, 4> res;
  auto mapAttrs = indexing_maps().getValue();
  res.reserve(mapAttrs.size());
  for (auto mapAttr : mapAttrs)
    res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
  return res;
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static Type inferExtractOpResultType(VectorType vectorType,
                                     ArrayAttr position) {
  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
    return vectorType.getElementType();
  return VectorType::get(vectorType.getShape().drop_front(position.size()),
                         vectorType.getElementType());
}
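
// Illustrative results (assumed shapes, not from a test):
//   position [1, 2]    applied to vector<4x8x16xf32> -> vector<16xf32>
//   position [1, 2, 3] applied to vector<4x8x16xf32> -> f32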

void vector::ExtractOp::build(Builder *builder, OperationState &result,
                              Value *source, ArrayRef<int32_t> position) {
  result.addOperands(source);
  auto positionAttr = builder->getI32ArrayAttr(position);
  result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
                                           positionAttr));
  result.addAttribute(getPositionAttrName(), positionAttr);
}

static void print(OpAsmPrinter &p, vector::ExtractOp op) {
  p << op.getOperationName() << " " << *op.vector() << op.position();
  p.printOptionalAttrDict(op.getAttrs(), {"position"});
  p << " : " << op.vector()->getType();
}

static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType vector;
  Type type;
  Attribute attr;
  if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(attr, "position", attrs) ||
      parser.parseOptionalAttrDict(attrs) ||
      parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
    return failure();

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typeLoc, "expected vector type");

  auto positionAttr = attr.dyn_cast<ArrayAttr>();
  if (!positionAttr ||
      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
    return parser.emitError(
        attributeLoc,
        "expected position attribute of rank smaller than vector rank");

  Type resType = inferExtractOpResultType(vectorType, positionAttr);
  result.attributes = attrs;
  return failure(parser.resolveOperand(vector, type, result.operands) ||
                 parser.addTypeToList(resType, result.types));
}

static LogicalResult verify(vector::ExtractOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= op.getVectorType().getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "vector dimension";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, BroadcastOp op) {
  p << op.getOperationName() << " " << *op.source();
  p << " : " << op.getSourceType();
  p << " to " << op.getVectorType();
}

static LogicalResult verify(BroadcastOp op) {
  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  VectorType dstVectorType = op.getVectorType();
  // Scalar to vector broadcast is always valid. A vector
  // to vector broadcast needs some additional checking.
  if (srcVectorType) {
    int64_t srcRank = srcVectorType.getRank();
    int64_t dstRank = dstVectorType.getRank();
    if (srcRank > dstRank)
      return op.emitOpError("source rank higher than destination rank");
    // Source has an exact match or singleton value for all trailing dimensions
    // (all leading dimensions are simply duplicated).
    int64_t lead = dstRank - srcRank;
    for (int64_t r = 0; r < srcRank; ++r) {
      int64_t srcDim = srcVectorType.getDimSize(r);
      int64_t dstDim = dstVectorType.getDimSize(lead + r);
      if (srcDim != 1 && srcDim != dstDim)
        return op.emitOpError("dimension mismatch (")
               << srcDim << " vs. " << dstDim << ")";
    }
  }
  return success();
}
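
// Illustrative cases (all shapes here are assumptions):
//   vector.broadcast %a : f32 to vector<2xf32>                  // ok: scalar
//   vector.broadcast %b : vector<1x4xf32> to vector<3x2x4xf32>  // ok: the 1
//                                                               // stretches
//   vector.broadcast %c : vector<2xf32> to vector<4xf32>        // rejected:
//                                                               // "2 vs. 4"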

static ParseResult parseBroadcastOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType source;
  Type sourceType;
  VectorType vectorType;
  return failure(parser.parseOperand(source) ||
                 parser.parseColonType(sourceType) ||
                 parser.parseKeywordType("to", vectorType) ||
                 parser.resolveOperand(source, sourceType, result.operands) ||
                 parser.addTypeToList(vectorType, result.types));
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::build(Builder *builder, OperationState &result, Value *source,
                     Value *dest, ArrayRef<int32_t> position) {
  result.addOperands({source, dest});
  auto positionAttr = builder->getI32ArrayAttr(position);
  result.addTypes(dest->getType());
  result.addAttribute(getPositionAttrName(), positionAttr);
}

static void print(OpAsmPrinter &p, InsertOp op) {
  p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
    << op.position();
  p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
  p << " : " << op.getSourceType();
  p << " into " << op.getDestVectorType();
}

static ParseResult parseInsertOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType source, dest;
  Type sourceType;
  VectorType destType;
  Attribute attr;
  return failure(parser.parseOperand(source) || parser.parseComma() ||
                 parser.parseOperand(dest) ||
                 parser.parseAttribute(attr, InsertOp::getPositionAttrName(),
                                       result.attributes) ||
                 parser.parseOptionalAttrDict(attrs) ||
                 parser.parseColonType(sourceType) ||
                 parser.parseKeywordType("into", destType) ||
                 parser.resolveOperand(source, sourceType, result.operands) ||
                 parser.resolveOperand(dest, destType, result.operands) ||
                 parser.addTypeToList(destType, result.types));
}

static LogicalResult verify(InsertOp op) {
  auto positionAttr = op.position().getValue();
  if (positionAttr.empty())
    return op.emitOpError("expected non-empty position attribute");
  auto destVectorType = op.getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError("expected position attribute rank + source rank to "
                          "match dest vector rank");
  else if (!srcVectorType && (positionAttr.size() !=
                              static_cast<unsigned>(destVectorType.getRank())))
    return op.emitOpError(
        "expected position attribute rank to match the dest vector rank");
  for (auto en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return op.emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
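
// Illustrative forms accepted by the parser and verifier above (shapes are
// assumptions, not from a test):
//   %1 = vector.insert %f, %v[3] : f32 into vector<4xf32>
//   %2 = vector.insert %w, %m[1] : vector<8xf32> into vector<4x8xf32>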

//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//

void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
                                 Value *source, Value *dest,
                                 ArrayRef<int64_t> offsets,
                                 ArrayRef<int64_t> strides) {
  result.addOperands({source, dest});
  auto offsetsAttr = builder->getI64ArrayAttr(offsets);
  auto stridesAttr = builder->getI64ArrayAttr(strides);
  result.addTypes(dest->getType());
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
  p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
    << " ";
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
}

static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
                                             OperationState &result) {
  OpAsmParser::OperandType source, dest;
  VectorType sourceVectorType, destVectorType;
  return failure(
      parser.parseOperand(source) || parser.parseComma() ||
      parser.parseOperand(dest) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(sourceVectorType) ||
      parser.parseKeywordType("into", destVectorType) ||
      parser.resolveOperand(source, sourceVectorType, result.operands) ||
      parser.resolveOperand(dest, destVectorType, result.operands) ||
      parser.addTypeToList(destVectorType, result.types));
}
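
// A sketch of the accepted textual form (attribute values and shapes are
// illustrative assumptions):
//
//   %2 = vector.insert_strided_slice %sub, %full
//          {offsets = [0, 2], strides = [1, 1]}
//      : vector<2x4xf32> into vector<4x8xf32>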

// TODO(ntv) Should be moved to Tablegen Confined attributes.
template <typename OpType>
LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr,
                                                 ArrayRef<int64_t> shape,
                                                 StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank smaller than vector rank";
  return success();
}

// Returns success if all integers in `arrayAttr` are in the given interval.
// If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr,
                                                int64_t min, int64_t max,
                                                StringRef attrName,
                                                bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}

// Returns success if each integer in `arrayAttr` lies in the admissible
// interval determined by the corresponding dimension `dim` of `shape`. If
// `halfOpen` is true then the admissible interval is [min, dim). Otherwise,
// the admissible interval is [min, dim].
template <typename OpType>
LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  assert(arrayAttr.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr, shape)) {
    auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<1>(it);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
    ++index;
  }
  return success();
}

// Returns success if, for each pair of elements drawn elementwise from
// `arrayAttr1` and `arrayAttr2`, the sum lies in the admissible interval
// determined by the corresponding dimension `dim` of `shape`. If `halfOpen`
// is true then the admissible interval is [min, dim). Otherwise, the
// admissible interval is [min, dim].
template <typename OpType>
LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  unsigned index = 0;
  for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
    auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
    auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
    auto max = std::get<2>(it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
    ++index;
  }
  return success();
}

static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
                                  MLIRContext *context) {
  auto attrs = functional::map(
      [context](int64_t v) -> Attribute {
        return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
      },
      values);
  return ArrayAttr::get(attrs, context);
}

static LogicalResult verify(InsertStridedSliceOp op) {
  auto sourceVectorType = op.getSourceVectorType();
  auto destVectorType = op.getDestVectorType();
  auto offsets = op.offsets();
  auto strides = op.strides();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return op.emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return op.emitOpError(
        "expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return op.emitOpError(
        "expected source rank to be smaller than destination rank");

  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  if (failed(
          isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          op, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}
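
// Worked example (assumed shapes): inserting vector<4xf32> into
// vector<4x8xf32> with offsets = [2, 3] pads the source shape to [0, 4], so
// the checked sums are [2 + 0, 3 + 4] = [2, 7]; both stay within the
// destination shape [4, 8] (inclusive, since halfOpen is false), so the op
// verifies.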

//===----------------------------------------------------------------------===//
// OuterProductOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, OuterProductOp op) {
  p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
  if (llvm::size(op.acc()) > 0)
    p << ", " << **op.acc().begin();
  p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
}

static ParseResult parseOuterProductOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
      parser.parseComma() || parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS || !vRHS)
    return parser.emitError(parser.getNameLoc(), "expected 2 vector types");
  VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                                       vLHS.getElementType());
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
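
// Accepted textual forms (shapes illustrative; the trailing accumulator
// operand is optional and, per the verifier below, must have the result
// type):
//
//   %1 = vector.outerproduct %lhs, %rhs : vector<17xf32>, vector<8xf32>
//   %2 = vector.outerproduct %lhs, %rhs, %acc : vector<17xf32>, vector<8xf32>
//
// Both forms produce a vector<17x8xf32> result.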
|
2018-12-14 09:31:17 -08:00
|
|
|
|
2019-11-22 07:52:02 -08:00
|
|
|
static LogicalResult verify(OuterProductOp op) {
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
VectorType vLHS = op.getOperandVectorTypeLHS(),
|
|
|
|
|
vRHS = op.getOperandVectorTypeRHS(),
|
|
|
|
|
vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
|
|
|
|
|
if (vLHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #1");
|
Extend vector.outerproduct with an optional 3rd argument
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and is lowered to FMA intrinsic when the lowering supports it.
In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).
This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.
This has been independently verified to result in proper fma instructions for haswell as follows.
Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
return %2 : vector<17x8xf32>
}
}
```
Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```
Output:
```
outerproduct_add: # @outerproduct_add
# %bb.0:
...
vmovaps 112(%rbp), %ymm8
vbroadcastss %xmm0, %ymm0
...
vbroadcastss 64(%rbp), %ymm15
vfmadd213ps 144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
...
vfmadd213ps 400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
...
```
PiperOrigin-RevId: 263743359
2019-08-16 03:52:56 -07:00
|
|
|
if (vRHS.getRank() != 1)
|
2019-08-09 06:55:10 -07:00
|
|
|
return op.emitOpError("expected 1-d vector for operand #2");
  if (vRES.getRank() != 2)
    return op.emitOpError("expected 2-d vector result");
  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
    return op.emitOpError("expected #1 operand dim to match result dim #1");
  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
    return op.emitOpError("expected #2 operand dim to match result dim #2");
  if (vACC && vACC != vRES)
    return op.emitOpError("expected operand #3 of same type as result type");
  return success();
}
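
// Illustrative shape constraints checked above (added comment, not from the
// original source): with %lhs : vector<17xf32> and %rhs : vector<8xf32>, the
// op must produce vector<17x8xf32>, and the optional accumulator operand, if
// present, must also have type vector<17x8xf32>.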

//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//

// Inference works as follows:
//   1. Add 'sizes' from prefix of dims in 'offsets'.
//   2. Add sizes from 'vectorType' for remaining dims.
static Type inferStridedSliceOpResultType(VectorType vectorType,
                                          ArrayAttr offsets, ArrayAttr sizes,
                                          ArrayAttr strides) {
  assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
  SmallVector<int64_t, 4> shape;
  shape.reserve(vectorType.getRank());
  unsigned idx = 0;
  for (unsigned e = offsets.size(); idx < e; ++idx)
    shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
  for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
    shape.push_back(vectorType.getShape()[idx]);

  return VectorType::get(shape, vectorType.getElementType());
}
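
// Illustrative inference (added comment, not from the original source): for
// vectorType = vector<4x8x16xf32> with offsets = [0, 2], sizes = [2, 4], and
// strides = [1, 1], the two leading dims come from 'sizes' and the trailing
// dim is carried over from 'vectorType', giving vector<2x4x16xf32>.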

void StridedSliceOp::build(Builder *builder, OperationState &result,
                           Value *source, ArrayRef<int64_t> offsets,
                           ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
  result.addOperands(source);
  auto offsetsAttr = builder->getI64ArrayAttr(offsets);
  auto sizesAttr = builder->getI64ArrayAttr(sizes);
  auto stridesAttr = builder->getI64ArrayAttr(strides);
  result.addTypes(
      inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(getOffsetsAttrName(), offsetsAttr);
  result.addAttribute(getSizesAttrName(), sizesAttr);
  result.addAttribute(getStridesAttrName(), stridesAttr);
}

static void print(OpAsmPrinter &p, StridedSliceOp op) {
  p << op.getOperationName() << " " << *op.vector();
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
}

static ParseResult parseStridedSliceOp(OpAsmParser &parser,
                                       OperationState &result) {
  llvm::SMLoc attributeLoc, typeLoc;
  OpAsmParser::OperandType vector;
  VectorType vectorType, resultVectorType;
  return failure(parser.parseOperand(vector) ||
                 parser.getCurrentLocation(&attributeLoc) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.getCurrentLocation(&typeLoc) ||
                 parser.parseColonType(vectorType) ||
                 parser.parseKeywordType("to", resultVectorType) ||
                 parser.resolveOperand(vector, vectorType, result.operands) ||
                 parser.addTypeToList(resultVectorType, result.types));
}
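
// Illustrative custom form printed/parsed above (added comment, not from the
// original source; exact attribute-dictionary syntax depends on MLIR version):
//   %1 = vector.strided_slice %0
//          {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//        : vector<4x8xf32> to vector<2x4xf32>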

static LogicalResult verify(StridedSliceOp op) {
  auto type = op.getVectorType();
  auto offsets = op.offsets();
  auto sizes = op.sizes();
  auto strides = op.strides();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
    op.emitOpError(
        "expected offsets, sizes and strides attributes of same size");
    return failure();
  }

  auto shape = type.getShape();
  auto offName = StridedSliceOp::getOffsetsAttrName();
  auto sizesName = StridedSliceOp::getSizesAttrName();
  auto stridesName = StridedSliceOp::getStridesAttrName();
  if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
                                                stridesName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
                                                    offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  auto resultType = inferStridedSliceOpResultType(
      op.getVectorType(), op.offsets(), op.sizes(), op.strides());
  if (op.getResult()->getType() != resultType) {
    op.emitOpError("expected result type to be ") << resultType;
    return failure();
  }

  return success();
}

namespace {

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(attr.cast<IntegerAttr>().getInt());
}

// Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<StridedSliceOp> {
public:
  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
                                     PatternRewriter &rewriter) const override {
    // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
    auto defOp = stridedSliceOp.vector()->getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return matchFailure();
    // Return if 'stridedSliceOp' has non-unit strides.
    if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return matchFailure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(stridedSliceOp.offsets(), sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(stridedSliceOp.sizes(), sliceSizes);

    // Compute slice of vector mask region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    assert(sliceOffsets.size() == maskDimSizes.size());
    for (const auto &it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t maskDimSize = std::get<0>(it);
      int64_t sliceOffset = std::get<1>(it);
      int64_t sliceSize = std::get<2>(it);
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
    // region is a conjunction of mask dim intervals).
    if (llvm::any_of(sliceMaskDimSizes, [](int64_t sz) { return sz == 0; }))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        stridedSliceOp, stridedSliceOp.getResult()->getType(),
        rewriter.getI64ArrayAttr(sliceMaskDimSizes));
    return matchSuccess();
  }
};
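
// Illustrative fold (added comment, not from the original source): slicing a
// vector.constant_mask [2, 2] : vector<4x4xi1> with offsets = [1, 1],
// sizes = [2, 2], strides = [1, 1] intersects each mask interval per dim:
// max(0, min(1 + 2, 2) - 1) == 1, so the pattern above produces
// vector.constant_mask [1, 1] : vector<2x2xi1>.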

} // end anonymous namespace

void StridedSliceOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // Pattern to rewrite a StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
  results.insert<StridedSliceConstantMaskFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransferReadOp
//===----------------------------------------------------------------------===//

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                          EmitFun emitOpError) {
  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
  for (auto expr : permutationMap.getResults()) {
    auto dim = expr.dyn_cast<AffineDimExpr>();
    auto zero = expr.dyn_cast<AffineConstantExpr>();
    if (zero) {
      if (zero.getValue() != 0) {
        return emitOpError(
            "requires a projected permutation_map (at most one dim or the zero "
            "constant can appear in each result)");
      }
      continue;
    }
    if (!dim) {
      return emitOpError("requires a projected permutation_map (at most one "
                         "dim or the zero constant can appear in each result)");
    }
    if (seen[dim.getPosition()]) {
      return emitOpError(
          "requires a permutation_map that is a permutation (found one dim "
          "used more than once)");
    }
    seen[dim.getPosition()] = true;
  }
  return success();
}
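
// Illustrative cases for the check above (added comment, not from the
// original source): (d0, d1, d2) -> (d2, d0) and (d0, d1) -> (d0, 0) pass;
// (d0, d1) -> (d0 + d1) is rejected (neither a dim nor the zero constant),
// and (d0, d1) -> (d0, d0) is rejected (a dim used more than once).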

static void print(OpAsmPrinter &p, TransferReadOp op) {
  p << op.getOperationName() << " ";
  p.printOperand(op.memref());
  p << "[";
  p.printOperands(op.indices());
  p << "], ";
  p.printOperand(op.padding());
  p << " ";
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getMemRefType();
  p << ", " << op.getVectorType();
}

ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 8> indexInfo;
  OpAsmParser::OperandType paddingInfo;
  SmallVector<Type, 2> types;
  // Parse the memref operand, its indices, and the padding value.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "two types required");
  auto indexType = parser.getBuilder().getIndexType();
  MemRefType memRefType = types[0].dyn_cast<MemRefType>();
  if (!memRefType)
    return parser.emitError(typesLoc, "memref type required"), failure();
  Type vectorType = types[1];
  return failure(
      parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, memRefType.getElementType(),
                            result.operands) ||
      parser.addTypeToList(vectorType, result.types));
}

static LogicalResult verify(TransferReadOp op) {
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = op.getMemRefType();
  VectorType vectorType = op.getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return op.emitOpError(
        "requires memref and vector types of the same elemental type");
  auto elementalType = op.padding()->getType();
  if (!VectorType::isValidElementType(elementalType))
    return op.emitOpError("requires valid padding vector elemental type");
  if (elementalType != vectorType.getElementType())
    return op.emitOpError(
        "requires formal padding and vector of the same elemental type");
  if (llvm::size(op.indices()) != memrefType.getRank())
    return op.emitOpError("requires ") << memrefType.getRank() << " indices";
  auto permutationMap = op.permutation_map();
  if (permutationMap.getNumSymbols() != 0)
    return op.emitOpError("requires permutation_map without symbols");
  if (permutationMap.getNumInputs() != memrefType.getRank())
    return op.emitOpError("requires a permutation_map with input dims of the "
                          "same rank as the memref type");
  if (permutationMap.getNumResults() != vectorType.getRank())
    return op.emitOpError("requires a permutation_map with result dims of the "
                          "same rank as the vector type");
  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}
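
// Illustrative printed form for the read op (added comment, not from the
// original source; the permutation_map attribute is elided), with a padding
// value %pad : f32:
//   %v = vector.transfer_read %A[%i, %j], %pad {permutation_map = ...}
//        : memref<?x?xf32>, vector<8xf32>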

//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, TransferWriteOp op) {
  p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref();
  p << "[";
  p.printOperands(op.indices());
  p << "]";
  p.printOptionalAttrDict(op.getAttrs());
  p << " : ";
  p.printType(op.getVectorType());
  p << ", ";
  p.printType(op.getMemRefType());
}

ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
  llvm::SMLoc typesLoc;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  SmallVector<Type, 2> types;
  if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memRefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "two types required");
  auto indexType = parser.getBuilder().getIndexType();
  Type vectorType = types[0], memRefType = types[1];
  return failure(
      parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
      parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands));
}

static LogicalResult verify(TransferWriteOp op) {
  // Consistency of elemental types in memref and vector.
  MemRefType memrefType = op.getMemRefType();
  VectorType vectorType = op.getVectorType();
  if (memrefType.getElementType() != vectorType.getElementType())
    return op.emitOpError(
        "requires memref and vector types of the same elemental type");
  if (llvm::size(op.indices()) != memrefType.getRank())
    return op.emitOpError("requires ") << memrefType.getRank() << " indices";

  // Consistency of AffineMap attribute.
  auto permutationMap = op.permutation_map();
  if (permutationMap.getNumSymbols() != 0)
    return op.emitOpError("requires a symbol-less permutation_map");
  if (permutationMap.getNumInputs() != memrefType.getRank())
    return op.emitOpError("requires a permutation_map with input dims of the "
                          "same rank as the memref type: ")
           << permutationMap.getNumInputs() << " vs " << memrefType;
  if (permutationMap.getNumResults() != vectorType.getRank())
    return op.emitOpError("requires a permutation_map with result dims of the "
                          "same rank as the vector type: ")
           << permutationMap.getNumResults() << " vs " << vectorType;
  return verifyPermutationMap(permutationMap,
                              [&op](Twine t) { return op.emitOpError(t); });
}

//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//

static MemRefType inferVectorTypeCastResultType(MemRefType t) {
  return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
}
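
// Illustrative example (added comment, not from the original source): a
// memref<2x4xf32> source infers the 0-d memref of the shape-matching vector
// type, i.e. memref<vector<2x4xf32>>.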

void TypeCastOp::build(Builder *builder, OperationState &result,
                       Value *source) {
  result.addOperands(source);
  result.addTypes(
      inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
}

static void print(OpAsmPrinter &p, TypeCastOp &op) {
  auto type = op.getOperand()->getType().cast<MemRefType>();
  p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
    << inferVectorTypeCastResultType(type);
}

static LogicalResult verify(TypeCastOp &op) {
  auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
  if (op.getResultMemRefType() != resultType)
    return op.emitOpError("expects result type to be: ") << resultType;
  return success();
}

//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//

ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
  Type resultType;
  ArrayAttr maskDimSizesAttr;
  StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
  return failure(
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(maskDimSizesAttr, attrName, result.attributes) ||
      parser.parseColonType(resultType) ||
      parser.addTypeToList(resultType, result.types));
}

static void print(OpAsmPrinter &p, ConstantMaskOp &op) {
  p << op.getOperationName() << ' ' << op.mask_dim_sizes();
  p << " : " << op.getResult()->getType();
}
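
// Illustrative example (added comment, not from the original source):
//   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
// sets the leading 3x2 sub-region of the 4x3 mask to 1 and the rest to 0.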

static LogicalResult verify(ConstantMaskOp &op) {
  // Verify that array attr size matches the rank of the vector result.
  auto resultType = op.getResult()->getType().cast<VectorType>();
  if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
    return op.emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  SmallVector<int64_t, 4> maskDimSizes;
  for (auto it : llvm::enumerate(op.mask_dim_sizes())) {
    int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
    if (attrValue < 0 || attrValue > resultShape[it.index()])
      return op.emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    maskDimSizes.push_back(attrValue);
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return op.emitOpError("expected all mask dim sizes to be zeros, "
                          "as a result of conjunction with zero mask dim");
  return success();
}

//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//

ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
  auto indexType = parser.getBuilder().getIndexType();
  Type resultType;
  SmallVector<OpAsmParser::OperandType, 4> operandInfo;
  return failure(
      parser.parseOperandList(operandInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType) ||
      parser.resolveOperands(operandInfo, indexType, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}

static void print(OpAsmPrinter &p, CreateMaskOp &op) {
  p << op.getOperationName() << ' ';
  p.printOperands(op.operands());
  p << " : " << op.getResult()->getType();
}

static LogicalResult verify(CreateMaskOp &op) {
  // Verify that an operand was specified for each result vector dimension.
  if (op.getNumOperands() !=
      op.getResult()->getType().cast<VectorType>().getRank())
    return op.emitOpError(
        "must specify an operand for each result vector dimension");
  return success();
}

namespace {

// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern<CreateMaskOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
                                     PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value *operand) {
      return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
      return matchFailure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    for (auto *operand : createMaskOp.operands()) {
      auto defOp = operand->getDefiningOp();
      maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult()->getType(),
        rewriter.getI64ArrayAttr(maskDimSizes));
    return matchSuccess();
  }
};
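
// Illustrative fold (added comment, not from the original source): with
//   %c3 = constant 3 : index
//   %c2 = constant 2 : index
//   %m = vector.create_mask %c3, %c2 : vector<4x3xi1>
// the pattern above rewrites %m to vector.constant_mask [3, 2] : vector<4x3xi1>.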

} // end anonymous namespace

void CreateMaskOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CreateMaskFolder>(context);
}

void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
}

namespace mlir {
namespace vector {

#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"

} // namespace vector
} // namespace mlir