Files
llvm/mlir/test/lib/Dialect/Test/TestDialect.cpp
Alex Zinenko 3ba14fa0ce [mlir] Introduce data layout modeling subsystem
Data layout information makes it possible to answer questions about the size and alignment
properties of a type. It enables, among others, the generation of various
linear memory addressing schemes for containers of abstract types and deeper
reasoning about vectors. This introduces the subsystem for modeling data
layouts in MLIR.

The data layout subsystem is designed to scale to MLIR's open type and
operation system. At the top level, it consists of attribute interfaces that
can be implemented by concrete data layout specifications; type interfaces that
should be implemented by types subject to data layout; operation interfaces
that must be implemented by operations that can serve as data layout scopes
(e.g., modules); and dialect interfaces for data layout properties unrelated to
specific types. Built-in types are handled specially to decrease the overall
query cost.

A concrete default implementation of these interfaces is provided in the new
Target dialect. Defaults for built-in types that match the current behavior are
also provided.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D97067
2021-03-11 16:54:47 +01:00

908 lines
33 KiB
C++

//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "TestAttributes.h"
#include "TestTypes.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
using namespace mlir::test;
/// Adds the test dialect to `registry`.
void mlir::test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
StringAttr strAttr = attr.dyn_cast<StringAttr>();
if (!strAttr)
return failure();
// Check the contents of the string attribute to see what the test alias
// should be named.
Optional<StringRef> aliasName =
StringSwitch<Optional<StringRef>>(strAttr.getValue())
.Case("alias_test:dot_in_name", StringRef("test.alias"))
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
.Case("alias_test:sanitize_conflict_a",
StringRef("test_alias_conflict0"))
.Case("alias_test:sanitize_conflict_b",
StringRef("test_alias_conflict0_"))
.Default(llvm::None);
if (!aliasName)
return failure();
os << *aliasName;
return success();
}
void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const final {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const final {
auto op = block->getParentOp();
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = block->getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}
};
struct TestDialectFoldInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;
/// Registered hook to check if the given region, which is attached to an
/// operation that is *not* isolated from above, should be used when
/// materializing constants.
bool shouldMaterializeInto(Region *region) const final {
// If this is a one region operation, then insert into it.
return isa<OneRegionOp>(region->getParentOp());
}
};
/// This class defines the interface for handling inlining with standard
/// operations.
struct TestInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
// Don't allow inlining calls that are marked `noinline`.
return !call->hasAttr("noinline");
}
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
// Inlining into test dialect regions is legal.
return true;
}
bool isLegalToInline(Operation *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
bool shouldAnalyzeRecursively(Operation *op) const final {
// Analyze recursively if this is not a functional region operation, it
// froms a separate functional scope.
return !isa<FunctionalRegionOp>(op);
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
// Only handle "test.return" here.
auto returnOp = dyn_cast<TestReturnOp>(op);
if (!returnOp)
return;
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
/// Attempt to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const final {
// Only allow conversion for i16/i32 types.
if (!(resultType.isSignlessInteger(16) ||
resultType.isSignlessInteger(32)) ||
!(input.getType().isSignlessInteger(16) ||
input.getType().isSignlessInteger(32)))
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
/// Registers the test dialect's operations, attributes, interfaces, and
/// types. The operation, attribute, and type lists are pulled in from the
/// `.cpp.inc` files included below, selected by the GET_*_LIST macros.
void TestDialect::initialize() {
// Operations listed in TestOps.cpp.inc.
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
// Attributes listed in TestAttrDefs.cpp.inc.
addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
>();
// Dialect interfaces defined earlier in this file.
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
// TestType, TestTypeWithLayout and TestRecursiveType are named explicitly,
// ahead of the types listed in TestTypeDefs.cpp.inc.
addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
// Also accept operations that are not registered with this dialect.
allowUnknownOperations();
}
/// Materializes a constant of `type` holding `value` as a `TestOpConstant`
/// at `loc`.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  auto constantOp = builder.create<TestOpConstant>(loc, type, value);
  return constantOp;
}
/// Check shared by the three attribute verification hooks below: an
/// attribute named `test.invalid_attr` is rejected with an error on `op`.
static LogicalResult verifyNoTestInvalidAttr(Operation *op,
                                             NamedAttribute namedAttr) {
  if (namedAttr.first == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  return verifyNoTestInvalidAttr(op, namedAttr);
}

LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  return verifyNoTestInvalidAttr(op, namedAttr);
}

LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  return verifyNoTestInvalidAttr(op, namedAttr);
}
//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
/// The op has a single successor (index 0). The operands forwarded to it are
/// the op's `targetOperands`.
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  MutableOperandRange forwarded = targetOperandsMutable();
  return forwarded;
}
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//
namespace {
/// Replaces a `FoldToCallOp` with a `CallOp` to the same callee. The new call
/// takes no operands and produces no results.
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
  using OpRewritePattern<FoldToCallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FoldToCallOp op,
                                PatternRewriter &rewriter) const override {
    auto callee = op.calleeAttr();
    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), callee, ValueRange());
    return success();
  }
};
} // end anonymous namespace
void FoldToCallOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // The op's only canonicalization rewrites it into a direct call.
  results.insert<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===//
// Test Format* operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Parsing
/// Parses `operand [, optOperand] -> (varOperands...)`.
static ParseResult parseCustomDirectiveOperands(
    OpAsmParser &parser, OpAsmParser::OperandType &operand,
    Optional<OpAsmParser::OperandType> &optOperand,
    SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
  if (parser.parseOperand(operand))
    return failure();

  // The optional operand is present only when introduced by a comma.
  if (succeeded(parser.parseOptionalComma())) {
    optOperand.emplace();
    if (parser.parseOperand(*optOperand))
      return failure();
  }

  return failure(parser.parseArrow() || parser.parseLParen() ||
                 parser.parseOperandList(varOperands) || parser.parseRParen());
}
/// Parses `: operandType [, optOperandType] -> (varOperandTypes...)`.
static ParseResult
parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
                            Type &optOperandType,
                            SmallVectorImpl<Type> &varOperandTypes) {
  if (parser.parseColon() || parser.parseType(operandType))
    return failure();

  // The optional type is present only when introduced by a comma.
  if (succeeded(parser.parseOptionalComma()) &&
      parser.parseType(optOperandType))
    return failure();

  return failure(parser.parseArrow() || parser.parseLParen() ||
                 parser.parseTypeList(varOperandTypes) ||
                 parser.parseRParen());
}
/// Parses `type_refs_capture` followed by types in the
/// parseCustomDirectiveResults form. Succeeds only when those types equal the
/// previously parsed `operandType`, `optOperandType` and `varOperandTypes`.
/// A mismatch fails without emitting a diagnostic.
static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
                                 Type optOperandType,
                                 const SmallVectorImpl<Type> &varOperandTypes) {
  if (parser.parseKeyword("type_refs_capture"))
    return failure();

  Type reparsedType, reparsedOptType;
  SmallVector<Type, 1> reparsedVarTypes;
  if (parseCustomDirectiveResults(parser, reparsedType, reparsedOptType,
                                  reparsedVarTypes))
    return failure();

  bool allMatch = operandType == reparsedType &&
                  optOperandType == reparsedOptType &&
                  varOperandTypes == reparsedVarTypes;
  return success(allMatch);
}
/// Parses the operand form of parseCustomDirectiveOperands, followed by the
/// type form of parseCustomDirectiveResults.
static ParseResult parseCustomDirectiveOperandsAndTypes(
    OpAsmParser &parser, OpAsmParser::OperandType &operand,
    Optional<OpAsmParser::OperandType> &optOperand,
    SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
    Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
  bool parseFailed =
      parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
      parseCustomDirectiveResults(parser, operandType, optOperandType,
                                  varOperandTypes);
  return failure(parseFailed);
}
/// Parses `region [, varRegion]*`. The first region is mandatory. Each
/// comma-separated region that follows is appended to `varRegions`.
static ParseResult parseCustomDirectiveRegions(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
  if (parser.parseRegion(region))
    return failure();

  // `varRegions` is a list, so accept every comma-separated trailing region
  // rather than just the first one. A single trailing region parses exactly
  // as before.
  while (succeeded(parser.parseOptionalComma())) {
    auto varRegion = std::make_unique<Region>();
    if (parser.parseRegion(*varRegion))
      return failure();
    varRegions.emplace_back(std::move(varRegion));
  }
  return success();
}
/// Parses `successor [, varSuccessor]`. When the optional trailing successor
/// is present, it is appended to `varSuccessors` twice.
/// NOTE(review): the double append looks deliberate. It produces a variadic
/// successor list with more than one entry, while
/// printCustomDirectiveSuccessors prints only the first entry. Confirm
/// before changing.
static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
SmallVectorImpl<Block *> &varSuccessors) {
if (parser.parseSuccessor(successor))
return failure();
// The variadic successor is optional and must be introduced by a comma.
if (failed(parser.parseOptionalComma()))
return success();
Block *varSuccessor;
if (parser.parseSuccessor(varSuccessor))
return failure();
// Appends two copies of the same block (see the note above).
varSuccessors.append(2, varSuccessor);
return success();
}
/// Parses `attr [, optAttr]` into two integer attributes.
static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
                                                  IntegerAttr &attr,
                                                  IntegerAttr &optAttr) {
  if (parser.parseAttribute(attr))
    return failure();

  // The second attribute is only present when preceded by a comma.
  bool hasOptAttr = succeeded(parser.parseOptionalComma());
  if (hasOptAttr && parser.parseAttribute(optAttr))
    return failure();
  return success();
}
/// Parses an optional attribute dictionary into `attrs`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  ParseResult dictResult = parser.parseOptionalAttrDict(attrs);
  return dictResult;
}
/// Parses an integer that records whether the referenced optional operand was
/// parsed. `0` means it must be absent; any other value means it must be
/// present. A mismatch fails without emitting a diagnostic.
static ParseResult parseCustomDirectiveOptionalOperandRef(
    OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
  int64_t operandCount = 0;
  if (parser.parseInteger(operandCount))
    return failure();
  bool expectOperand = operandCount != 0;
  return success(expectOperand == optOperand.hasValue());
}
//===----------------------------------------------------------------------===//
// Printing
/// Prints `operand [, optOperand] -> (varOperands...)`.
static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
                                         Value operand, Value optOperand,
                                         OperandRange varOperands) {
  printer << operand;
  if (optOperand)
    printer << ", " << optOperand;
  printer << " -> (";
  printer << varOperands;
  printer << ")";
}
/// Prints ` : operandType [, optOperandType] -> (varOperandTypes...)`.
static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
                                        Type operandType, Type optOperandType,
                                        TypeRange varOperandTypes) {
  printer << " : " << operandType;
  if (optOperandType)
    printer << ", " << optOperandType;
  printer << " -> (";
  printer << varOperandTypes;
  printer << ")";
}
/// Prints the `type_refs_capture` keyword, then the types in the
/// printCustomDirectiveResults form.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operands in the printCustomDirectiveOperands form, followed by
/// the types in the printCustomDirectiveResults form.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints `region [, varRegion]*`: the mandatory region, then each variadic
/// region preceded by ", ".
static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
                                        Region &region,
                                        MutableArrayRef<Region> varRegions) {
  printer.printRegion(region);
  // Emit a separator before every variadic region, so several regions stay
  // comma-separated instead of being printed back to back. The output for
  // zero or one variadic region is unchanged. The distinct loop variable
  // name also avoids shadowing the `region` parameter.
  for (Region &varRegion : varRegions) {
    printer << ", ";
    printer.printRegion(varRegion);
  }
}
/// Prints `successor [, firstVarSuccessor]`. Only the first variadic
/// successor, if any, is printed.
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
                                           Block *successor,
                                           SuccessorRange varSuccessors) {
  printer << successor;
  if (varSuccessors.empty())
    return;
  printer << ", " << varSuccessors.front();
}
/// Prints `attribute [, optAttribute]`.
static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
                                           Attribute attribute,
                                           Attribute optAttribute) {
  printer << attribute;
  if (!optAttribute)
    return;
  printer << ", " << optAttribute;
}
/// Prints the entries of `attrs` with the printer's optional attribute
/// dictionary form.
static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                         DictionaryAttr attrs) {
  auto namedAttrs = attrs.getValue();
  printer.printOptionalAttrDict(namedAttrs);
}
/// Prints `1` when the optional operand is present and `0` otherwise.
/// parseCustomDirectiveOptionalOperandRef checks this flag against the
/// operand it actually parsed.
static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
                                                   Operation *op,
                                                   Value optOperand) {
  if (optOperand)
    printer << "1";
  else
    printer << "0";
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
/// Parses `%operand { ... }`. The operand is resolved as an `index` value.
/// Its SSA name and type are reused for the body region's entry argument,
/// with name shadowing enabled.
static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  Type indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType operandInfo;
  if (parser.parseOperand(operandInfo) ||
      parser.resolveOperand(operandInfo, indexType, result.operands))
    return failure();

  return parser.parseRegion(*result.addRegion(), operandInfo, indexType,
                            /*enableNameShadowing=*/true);
}
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
  // The entry block argument is not printed. It is shadowed by the operand's
  // name, mirroring parseIsolatedRegionOp.
  Value input = op.getOperand();
  p << "test.isolated_region ";
  p.printOperand(input);
  p.shadowRegionArgs(op.region(), input);
  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test SSACFGRegionOp
//===----------------------------------------------------------------------===//
/// Every region of this op is an SSA control-flow-graph region.
RegionKind SSACFGRegionOp::getRegionKind(unsigned /*index*/) {
  return RegionKind::SSACFG;
}
//===----------------------------------------------------------------------===//
// Test GraphRegionOp
//===----------------------------------------------------------------------===//
/// Parses the op's single body region. The region has no entry arguments.
static ParseResult parseGraphRegionOp(OpAsmParser &parser,
                                      OperationState &result) {
  Region &body = *result.addRegion();
  return parser.parseRegion(body, /*arguments=*/{}, /*argTypes=*/{});
}
static void print(OpAsmPrinter &p, GraphRegionOp op) {
  // The parser creates no entry arguments, so none are printed.
  Region &body = op.region();
  p << "test.graph_region ";
  p.printRegion(body, /*printEntryBlockArgs=*/false);
}
/// Every region of this op is a graph region.
RegionKind GraphRegionOp::getRegionKind(unsigned /*index*/) {
  return RegionKind::Graph;
}
//===----------------------------------------------------------------------===//
// Test AffineScopeOp
//===----------------------------------------------------------------------===//
/// Parses the op's single body region. The region has no entry arguments.
static ParseResult parseAffineScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  Region &body = *result.addRegion();
  return parser.parseRegion(body, /*arguments=*/{}, /*argTypes=*/{});
}
static void print(OpAsmPrinter &p, AffineScopeOp op) {
  // The parser creates no entry arguments, so none are printed.
  Region &body = op.region();
  p << "test.affine_scope ";
  p.printRegion(body, /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//
/// Parses an optional `: N` suffix and gives the operation N results of
/// `index` type. Without the suffix the operation has no results.
static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
                                              OperationState &result) {
  if (parser.parseOptionalColon())
    return success();
  uint64_t numResults;
  if (parser.parseInteger(numResults))
    return failure();

  IndexType type = parser.getBuilder().getIndexType();
  // The counter must be as wide as `numResults`. An `unsigned` counter would
  // wrap around, and the loop would never end for counts above UINT_MAX.
  for (uint64_t i = 0; i < numResults; ++i)
    result.addTypes(type);
  return success();
}
static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
  p << ParseIntegerLiteralOp::getOperationName();
  // Mirror the parser: the `: N` suffix is only emitted when N is non-zero.
  unsigned numResults = op->getNumResults();
  if (numResults != 0)
    p << " : " << numResults;
}
/// Parses a bare keyword and stores it in the `keyword` string attribute.
static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
                                              OperationState &result) {
  StringRef keyword;
  if (failed(parser.parseKeyword(&keyword)))
    return failure();
  auto keywordAttr = parser.getBuilder().getStringAttr(keyword);
  result.addAttribute("keyword", keywordAttr);
  return success();
}
static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
  // Op name, a space, then the stored keyword.
  p << ParseWrappedKeywordOp::getOperationName();
  p << " " << op.keyword();
}
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
/// Parses `wraps <generic-op>`. The wrapped operation is parsed in generic
/// form into a new block of the op's region, and a `TestReturnOp` returning
/// all of its results is appended. The wrapping op takes its result types
/// from that terminator and its location from the wrapped op.
static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Parse the wrapped op into a fresh block of the region.
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &block = body.back();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // Terminate the block by returning every result of the wrapped op.
  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);

  // The wrapping op's results mirror the terminator's operand types.
  Operation &terminator = body.back().back();
  result.types.append(terminator.operand_type_begin(),
                      terminator.operand_type_end());

  // Reuse the wrapped op's location for the wrapping op.
  result.location = wrappedOp->getLoc();
  return success();
}
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  // Only the first op of the body is printed, in generic form. The parser
  // recreates the terminator.
  Operation *wrapped = &op.region().front().front();
  p << op.getOperationName() << " wraps ";
  p.printGenericOp(wrapped);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
/// Parses an undelimited list of region arguments, then the body region.
/// Every region argument has `index` type.
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> inductionVars;
  if (parser.parseRegionArgumentList(inductionVars))
    return failure();

  Region *body = result.addRegion();
  Type indexType = parser.getBuilder().getIndexType();
  SmallVector<Type, 4> argTypes(inductionVars.size(), indexType);
  return parser.parseRegion(*body, inductionVars, argTypes);
}
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
namespace {
/// Canonicalization pattern that unconditionally erases a
/// `TestOpWithRegionPattern`.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;

  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
                                PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // end anonymous namespace
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // The only canonicalization erases the op.
  results.insert<TestRemoveOpWithInnerOps>(context);
}
/// Always folds to the op's `operand`.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  Value input = operand();
  return input;
}
/// A constant folds to its stored value.
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
  auto constantValue = getValue();
  return constantValue;
}
/// Folds by forwarding every operand, in order, as a fold result.
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
  for (Value forwarded : this->operands())
    results.push_back(forwarded);
  return success();
}
/// When the single operand is a known constant, stores it in the `attr`
/// attribute and folds in place to the op's own result. Otherwise it does
/// not fold.
OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1);
  Attribute constant = operands.front();
  if (!constant)
    return {};
  (*this)->setAttr("attr", constant);
  return getResult();
}
/// Always folds to the op's operand.
OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
  Value passthrough = getOperand();
  return passthrough;
}
/// Infers a single result whose type equals the type of the first two
/// operands. Fails, emitting an error when a location is given, if there are
/// fewer than two operands or if their types differ.
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
    MLIRContext *, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Nothing in this function guarantees two operands, so guard the indexing
  // below instead of reading out of bounds.
  // NOTE(review): confirm whether callers always verify arity first.
  if (operands.size() < 2)
    return emitOptionalError(location, "expected at least two operands");

  Type lhsType = operands[0].getType();
  Type rhsType = operands[1].getType();
  if (lhsType != rhsType)
    return emitOptionalError(location, "operand type mismatch ", lhsType,
                             " vs ", rhsType);
  inferredReturnTypes.assign({lhsType});
  return success();
}
/// Infers one result shape with a single dimension and element type i17. That
/// dimension equals the *first* dimension of the first operand, or is dynamic
/// when that operand is unranked. Fails, emitting an error when a location is
/// given, if there are no operands, if the first operand is not shaped, or if
/// it is ranked with rank 0.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Reading the first operand of an empty range would be undefined.
  if (operands.empty())
    return emitOptionalError(location, "expected at least one operand");

  auto sval = operands.front().getType().dyn_cast<ShapedType>();
  if (!sval)
    return emitOptionalError(location, "only shaped type operands allowed");

  // A rank-0 shape has no leading dimension, so getShape().front() would
  // read from an empty shape.
  if (sval.hasRank() && sval.getRank() == 0)
    return emitOptionalError(location, "expected operand of rank >= 1");

  int64_t dim =
      sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
  auto type = IntegerType::get(context, 17);
  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
  return success();
}
/// Reifies the single result shape as the size of dimension 0 of the first
/// operand. Any previous contents of `shapes` are discarded.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
  Value leadingDim = builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0);
  shapes.clear();
  shapes.push_back(leadingDim);
  return success();
}
//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===//
namespace {
/// Side-effect resource used by the test dialect. Its name is "<Test>".
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  StringRef getName() final { return StringRef("<Test>"); }
};
} // end anonymous namespace
/// Reports memory effects described by the optional `effects` array
/// attribute. Each element is a dictionary with:
///   - an `effect` string: allocate, free, read or write;
///   - an optional `test_resource` entry, which selects TestResource instead
///     of the default resource;
///   - an optional `on_result` entry, which attaches the effect to the result;
///   - an optional `on_reference` symbol reference, which attaches the effect
///     to that symbol.
void SideEffectOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
  if (!effectsAttr)
    return;

  for (Attribute element : effectsAttr) {
    DictionaryAttr effectElement = element.cast<DictionaryAttr>();

    // Map the effect name onto a memory effect. The explicit default keeps
    // StringSwitch from falling off the end. That is an assertion failure in
    // asserting builds and undefined behavior otherwise.
    MemoryEffects::Effect *effect =
        StringSwitch<MemoryEffects::Effect *>(
            effectElement.get("effect").cast<StringAttr>().getValue())
            .Case("allocate", MemoryEffects::Allocate::get())
            .Case("free", MemoryEffects::Free::get())
            .Case("read", MemoryEffects::Read::get())
            .Case("write", MemoryEffects::Write::get())
            .Default(nullptr);
    assert(effect && "unknown 'effect' kind in test 'effects' attribute");
    if (!effect)
      continue;

    // Check for a non-default resource to use.
    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
    if (effectElement.get("test_resource"))
      resource = TestResource::get();

    // Check for a result or symbol to affect.
    if (effectElement.get("on_result"))
      effects.emplace_back(effect, getResult(), resource);
    else if (Attribute ref = effectElement.get("on_reference"))
      effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
    else
      effects.emplace_back(effect, resource);
  }
}
/// Reports one `TestEffects::Concrete` effect, parameterized by the
/// `effect_parameter` affine map attribute, when that attribute is present.
void SideEffectOp::getEffects(
    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
  auto parameter = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
  if (parameter)
    effects.emplace_back(TestEffects::Concrete::get(), parameter);
}
//===----------------------------------------------------------------------===//
// StringAttrPrettyNameOp
//===----------------------------------------------------------------------===//
// This op has fancy handling of its SSA result name.
/// Gives the op one i32 result per result named in the source, then parses
/// an optional attribute dictionary (with the `attributes` keyword). If no
/// `names` attribute was given and the op has results, one is synthesized
/// from the SSA result names. Names starting with a digit, such as `%0`, are
/// recorded as empty strings.
static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
                                               OperationState &result) {
  // Add the result types.
  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
    result.addTypes(parser.getBuilder().getIntegerType(32));

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // An explicit 'names' attribute takes precedence over the SSA names.
  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
    return attr.first == "names";
  });
  if (hadNames || parser.getNumResults() == 0)
    return success();

  // Otherwise infer 'names' from the SSA names used in the asm file.
  SmallVector<StringRef, 4> names;
  auto *context = result.getContext();
  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
    auto resultName = parser.getResultName(i);
    StringRef nameStr;
    // Cast to unsigned char first: passing a negative `char` to isdigit is
    // undefined behavior.
    if (!resultName.first.empty() &&
        !isdigit(static_cast<unsigned char>(resultName.first[0])))
      nameStr = resultName.first;
    names.push_back(nameStr);
  }

  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
  result.attributes.push_back({Identifier::get("names", context), namesAttr});
  return success();
}
static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
  p << "test.string_attr_pretty_name";

  // The 'names' attribute only needs to be printed when it cannot be
  // recovered from the SSA names the printer chose. That can happen, for
  // example, when the printer had to rename a result to resolve a conflict.
  auto nameAttrs = op.names();
  bool namesAgree = nameAttrs.size() == op.getNumResults();
  SmallString<32> printedName;
  for (size_t idx = 0, numResults = op.getNumResults();
       namesAgree && idx != numResults; ++idx) {
    printedName.clear();
    llvm::raw_svector_ostream nameStream(printedName);
    p.printOperand(op.getResult(idx), nameStream);
    auto expected = nameAttrs[idx].dyn_cast<StringAttr>();
    // drop_front() strips the leading sigil from the printed name.
    namesAgree =
        expected && nameStream.str().drop_front() == expected.getValue();
  }

  if (namesAgree)
    p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
  else
    p.printOptionalAttrDictWithKeyword(op->getAttrs());
}
// We set the SSA name in the asm syntax to the contents of the name
// attribute.
void StringAttrPrettyNameOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto nameAttrs = names();
  for (size_t idx = 0, count = nameAttrs.size(); idx != count; ++idx) {
    auto name = nameAttrs[idx].dyn_cast<StringAttr>();
    // Non-string and empty entries leave the result's default name alone.
    if (!name || name.getValue().empty())
      continue;
    setNameFn(getResult(idx), name.getValue());
  }
}
//===----------------------------------------------------------------------===//
// RegionIfOp
//===----------------------------------------------------------------------===//
/// Prints `<op-name> operands : operand-types -> result-types`, then the
/// `then`, `else` and `join` regions, each with its entry block arguments and
/// terminators.
static void print(OpAsmPrinter &p, RegionIfOp op) {
  p << RegionIfOp::getOperationName() << " ";
  p.printOperands(op.getOperands());
  p << ": " << op.getOperandTypes();
  p.printArrowTypeList(op.getResultTypes());

  // All three regions are printed with the same options.
  auto printKeywordAndRegion = [&](const char *keyword, Region &region) {
    p << keyword;
    p.printRegion(region,
                  /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true);
  };
  printKeywordAndRegion(" then", op.thenRegion());
  printKeywordAndRegion(" else", op.elseRegion());
  printKeywordAndRegion(" join", op.joinRegion());
}
/// Parses `operands : operand-types -> result-types then {...} else {...}
/// join {...}`. Regions are added in then/else/join order.
static ParseResult parseRegionIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  SmallVector<Type, 2> operandTypes;

  result.regions.reserve(3);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();
  Region *joinRegion = result.addRegion();

  // Record where the operand list starts. Operand-resolution errors (for
  // example an operand/type count mismatch) then point at the operands
  // rather than at the end of the join region.
  auto operandsLoc = parser.getCurrentLocation();

  // Parse operand, type and arrow type lists.
  if (parser.parseOperandList(operandInfos) ||
      parser.parseColonTypeList(operandTypes) ||
      parser.parseArrowTypeList(result.types))
    return failure();

  // Parse all attached regions.
  if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
      parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
      parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
    return failure();

  return parser.resolveOperands(operandInfos, operandTypes, operandsLoc,
                                result.operands);
}
/// Regions 0 and 1 (`then` and `else`) are both entered with all of the op's
/// operands.
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  assert(index < 2 && "invalid region index");
  OperandRange entryOperands = getOperands();
  return entryOperands;
}
/// Control flow:
///   - from the parent op, control enters either `then` or `else`;
///   - both of those regions continue into `join`;
///   - `join` (and any higher index) returns to the parent op's results.
void RegionIfOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // Entry from the parent op: the then and else regions are the entries.
  if (!index.hasValue()) {
    regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
    regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
    return;
  }

  // Regions 0 (then) and 1 (else) flow into join; join exits to the results.
  bool fromThenOrElse = index.getValue() < 2;
  if (fromThenOrElse)
    regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
  else
    regions.push_back(RegionSuccessor(getResults()));
}
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestOpStructs.cpp.inc"
#include "TestTypeInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "TestOps.cpp.inc"