mirror of
https://github.com/intel/llvm.git
synced 2026-01-18 16:50:51 +08:00
[mlir][Transforms] Dialect Conversion: Simplify block conversion API (#94866)
This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions. There are now two public functions for signature conversion: * `applySignatureConversion` converts a single block signature. This function used to take a `Region *` (but converted only the entry block). It now takes a `Block *`. * `convertRegionTypes` converts all block signatures of a region. `convertNonEntryRegionTypes` is removed because it is not widely used and can easily be expressed with a call to `applySignatureConversion` inside a loop. (See `Detensorize.cpp` for an example.) Note: For consistency, `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future. (Or `applySignatureConversion` renamed to `convertBlockTypes`.) Also clarify when a type converter and/or signature conversion object is needed and for what purpose. Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive. From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC. Note for LLVM integration: When you see `applySignatureConversion(region, ...)`, replace with `applySignatureConversion(region->front(), ...)`. In the unlikely case that you see `convertNonEntryRegionTypes`, apply the same changes as this commit did to `Detensorize.cpp`. --------- Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
This commit is contained in:
committed by
GitHub
parent
65310f34d7
commit
52050f3ff3
@@ -372,19 +372,23 @@ class TypeConverter {
|
||||
From the perspective of type conversion, the types of block arguments are a bit
|
||||
special. Throughout the conversion process, blocks may move between regions of
|
||||
different operations. Given this, the conversion of the types for blocks must be
|
||||
done explicitly via a conversion pattern. To convert the types of block
|
||||
arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
|
||||
be invoked; `convertRegionTypes`. This hook uses a provided type converter to
|
||||
apply type conversions to all blocks within a given region, and all blocks that
|
||||
move into that region. As noted above, the conversions performed by this method
|
||||
use the argument materialization hook on the `TypeConverter`. This hook also
|
||||
takes an optional `TypeConverter::SignatureConversion` parameter that applies a
|
||||
custom conversion to the entry block of the region. The types of the entry block
|
||||
arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
|
||||
AffineForOp, etc. To convert the signature of just the region entry block, and
|
||||
not any other blocks within the region, the `applySignatureConversion` hook may
|
||||
be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
|
||||
can be built programmatically:
|
||||
done explicitly via a conversion pattern.
|
||||
|
||||
To convert the types of block arguments within a Region, a custom hook on the
|
||||
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
|
||||
uses a provided type converter to apply type conversions to all blocks of a
|
||||
given region. As noted above, the conversions performed by this method use the
|
||||
argument materialization hook on the `TypeConverter`. This hook also takes an
|
||||
optional `TypeConverter::SignatureConversion` parameter that applies a custom
|
||||
conversion to the entry block of the region. The types of the entry block
|
||||
arguments are often tied semantically to the operation, e.g.,
|
||||
`func::FuncOp`, `AffineForOp`, etc.
|
||||
|
||||
To convert the signature of just one given block, the
|
||||
`applySignatureConversion` hook can be used.
|
||||
|
||||
A signature conversion, `TypeConverter::SignatureConversion`, can be built
|
||||
programmatically:
|
||||
|
||||
```c++
|
||||
class SignatureConversion {
|
||||
|
||||
@@ -247,7 +247,8 @@ public:
|
||||
/// Attempts a 1-1 type conversion, expecting the result type to be
|
||||
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
|
||||
/// and a null type on conversion or cast failure.
|
||||
template <typename TargetType> TargetType convertType(Type t) const {
|
||||
template <typename TargetType>
|
||||
TargetType convertType(Type t) const {
|
||||
return dyn_cast_or_null<TargetType>(convertType(t));
|
||||
}
|
||||
|
||||
@@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
|
||||
public:
|
||||
~ConversionPatternRewriter() override;
|
||||
|
||||
/// Apply a signature conversion to the entry block of the given region. This
|
||||
/// replaces the entry block with a new block containing the updated
|
||||
/// signature. The new entry block to the region is returned for convenience.
|
||||
/// If no block argument types are changing, the entry original block will be
|
||||
/// Apply a signature conversion to given block. This replaces the block with
|
||||
/// a new block containing the updated signature. The operations of the given
|
||||
/// block are inlined into the newly-created block, which is returned.
|
||||
///
|
||||
/// If no block argument types are changing, the original block will be
|
||||
/// left in place and returned.
|
||||
///
|
||||
/// If provided, `converter` will be used for any materializations.
|
||||
/// A signature converison must be provided. (Type converters can construct
|
||||
/// a signature conversion with `convertBlockSignature`.)
|
||||
///
|
||||
/// Optionally, a type converter can be provided to build materializations.
|
||||
/// Note: If no type converter was provided or the type converter does not
|
||||
/// specify any suitable argument/target materialization rules, the dialect
|
||||
/// conversion may fail to legalize unresolved materializations.
|
||||
Block *
|
||||
applySignatureConversion(Region *region,
|
||||
applySignatureConversion(Block *block,
|
||||
TypeConverter::SignatureConversion &conversion,
|
||||
const TypeConverter *converter = nullptr);
|
||||
|
||||
/// Convert the types of block arguments within the given region. This
|
||||
/// Apply a signature conversion to each block in the given region. This
|
||||
/// replaces each block with a new block containing the updated signature. If
|
||||
/// an updated signature would match the current signature, the respective
|
||||
/// block is left in place as is.
|
||||
/// block is left in place as is. (See `applySignatureConversion` for
|
||||
/// details.) The new entry block of the region is returned.
|
||||
///
|
||||
/// The entry block may have a special conversion if `entryConversion` is
|
||||
/// provided. On success, the new entry block to the region is returned for
|
||||
/// convenience. Otherwise, failure is returned.
|
||||
/// SignatureConversions are computed with the specified type converter.
|
||||
/// This function returns "failure" if the type converter failed to compute
|
||||
/// a SignatureConversion for at least one block.
|
||||
///
|
||||
/// Optionally, a special SignatureConversion can be specified for the entry
|
||||
/// block. This is because the types of the entry block arguments are often
|
||||
/// tied semantically to the operation.
|
||||
FailureOr<Block *> convertRegionTypes(
|
||||
Region *region, const TypeConverter &converter,
|
||||
TypeConverter::SignatureConversion *entryConversion = nullptr);
|
||||
|
||||
/// Convert the types of block arguments within the given region except for
|
||||
/// the entry region. This replaces each non-entry block with a new block
|
||||
/// containing the updated signature. If an updated signature would match the
|
||||
/// current signature, the respective block is left in place as is.
|
||||
///
|
||||
/// If special conversion behavior is needed for the non-entry blocks (for
|
||||
/// example, we need to convert only a subset of a BB arguments), such
|
||||
/// behavior can be specified in blockConversions.
|
||||
LogicalResult convertNonEntryRegionTypes(
|
||||
Region *region, const TypeConverter &converter,
|
||||
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
|
||||
|
||||
/// Replace all the uses of the block argument `from` with value `to`.
|
||||
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
|
||||
signatureConverter.remapInput(0, newIndVar);
|
||||
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
|
||||
signatureConverter.remapInput(i, header->getArgument(i));
|
||||
body = rewriter.applySignatureConversion(&forOp.getRegion(),
|
||||
body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
|
||||
signatureConverter);
|
||||
|
||||
// Move the blocks from the forOp into the loopOp. This is the body of the
|
||||
|
||||
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.startOpModification(op);
|
||||
Region ®ion = op.getFunctionBody();
|
||||
SmallVector<TypeConverter::SignatureConversion, 2> conversions;
|
||||
|
||||
for (Block &block : llvm::drop_begin(region, 1)) {
|
||||
conversions.emplace_back(block.getNumArguments());
|
||||
TypeConverter::SignatureConversion &back = conversions.back();
|
||||
for (Block &block :
|
||||
llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
|
||||
TypeConverter::SignatureConversion conversion(
|
||||
/*numOrigInputs=*/block.getNumArguments());
|
||||
|
||||
for (BlockArgument blockArgument : block.getArguments()) {
|
||||
int idx = blockArgument.getArgNumber();
|
||||
|
||||
if (blockArgsToDetensor.count(blockArgument))
|
||||
back.addInputs(idx, {getTypeConverter()->convertType(
|
||||
block.getArgumentTypes()[idx])});
|
||||
conversion.addInputs(idx, {getTypeConverter()->convertType(
|
||||
block.getArgumentTypes()[idx])});
|
||||
else
|
||||
back.addInputs(idx, {block.getArgumentTypes()[idx]});
|
||||
conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
|
||||
}
|
||||
}
|
||||
|
||||
if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
|
||||
conversions))) {
|
||||
rewriter.cancelOpModification(op);
|
||||
return failure();
|
||||
rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
|
||||
}
|
||||
|
||||
rewriter.finalizeOpModification(op);
|
||||
|
||||
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
// Type Conversion
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Attempt to convert the signature of the given block, if successful a new
|
||||
/// block is returned containing the new arguments. Returns `block` if it did
|
||||
/// not require conversion.
|
||||
FailureOr<Block *> convertBlockSignature(
|
||||
ConversionPatternRewriter &rewriter, Block *block,
|
||||
const TypeConverter *converter,
|
||||
TypeConverter::SignatureConversion *conversion = nullptr);
|
||||
|
||||
/// Convert the types of non-entry block arguments within the given region.
|
||||
LogicalResult convertNonEntryRegionTypes(
|
||||
ConversionPatternRewriter &rewriter, Region *region,
|
||||
const TypeConverter &converter,
|
||||
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
|
||||
|
||||
/// Apply a signature conversion on the given region, using `converter` for
|
||||
/// materializations if not null.
|
||||
Block *
|
||||
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
|
||||
TypeConverter::SignatureConversion &conversion,
|
||||
const TypeConverter *converter);
|
||||
|
||||
/// Convert the types of block arguments within the given region.
|
||||
FailureOr<Block *>
|
||||
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
|
||||
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Conversion
|
||||
|
||||
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
|
||||
ConversionPatternRewriter &rewriter, Block *block,
|
||||
const TypeConverter *converter,
|
||||
TypeConverter::SignatureConversion *conversion) {
|
||||
if (conversion)
|
||||
return applySignatureConversion(rewriter, block, converter, *conversion);
|
||||
|
||||
// If a converter wasn't provided, and the block wasn't already converted,
|
||||
// there is nothing we can do.
|
||||
if (!converter)
|
||||
return failure();
|
||||
|
||||
// Try to convert the signature for the block with the provided converter.
|
||||
if (auto conversion = converter->convertBlockSignature(block))
|
||||
return applySignatureConversion(rewriter, block, converter, *conversion);
|
||||
return failure();
|
||||
}
|
||||
|
||||
Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
||||
ConversionPatternRewriter &rewriter, Region *region,
|
||||
TypeConverter::SignatureConversion &conversion,
|
||||
const TypeConverter *converter) {
|
||||
if (!region->empty())
|
||||
return *convertBlockSignature(rewriter, ®ion->front(), converter,
|
||||
&conversion);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
|
||||
ConversionPatternRewriter &rewriter, Region *region,
|
||||
const TypeConverter &converter,
|
||||
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
|
||||
if (region->empty())
|
||||
return nullptr;
|
||||
|
||||
if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
|
||||
return failure();
|
||||
|
||||
FailureOr<Block *> newEntry = convertBlockSignature(
|
||||
rewriter, ®ion->front(), &converter, entryConversion);
|
||||
return newEntry;
|
||||
}
|
||||
|
||||
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
|
||||
ConversionPatternRewriter &rewriter, Region *region,
|
||||
const TypeConverter &converter,
|
||||
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
|
||||
regionToConverter[region] = &converter;
|
||||
if (region->empty())
|
||||
return success();
|
||||
|
||||
// Convert the arguments of each block within the region.
|
||||
int blockIdx = 0;
|
||||
assert((blockConversions.empty() ||
|
||||
blockConversions.size() == region->getBlocks().size() - 1) &&
|
||||
"expected either to provide no SignatureConversions at all or to "
|
||||
"provide a SignatureConversion for each non-entry block");
|
||||
|
||||
// Convert the arguments of each non-entry block within the region.
|
||||
for (Block &block :
|
||||
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
|
||||
TypeConverter::SignatureConversion *blockConversion =
|
||||
blockConversions.empty()
|
||||
? nullptr
|
||||
: const_cast<TypeConverter::SignatureConversion *>(
|
||||
&blockConversions[blockIdx++]);
|
||||
|
||||
if (failed(convertBlockSignature(rewriter, &block, &converter,
|
||||
blockConversion)))
|
||||
// Compute the signature for the block with the provided converter.
|
||||
std::optional<TypeConverter::SignatureConversion> conversion =
|
||||
converter.convertBlockSignature(&block);
|
||||
if (!conversion)
|
||||
return failure();
|
||||
// Convert the block with the computed signature.
|
||||
applySignatureConversion(rewriter, &block, &converter, *conversion);
|
||||
}
|
||||
return success();
|
||||
|
||||
// Convert the entry block. If an entry signature conversion was provided,
|
||||
// use that one. Otherwise, compute the signature with the type converter.
|
||||
if (entryConversion)
|
||||
return applySignatureConversion(rewriter, ®ion->front(), &converter,
|
||||
*entryConversion);
|
||||
std::optional<TypeConverter::SignatureConversion> conversion =
|
||||
converter.convertBlockSignature(®ion->front());
|
||||
if (!conversion)
|
||||
return failure();
|
||||
return applySignatureConversion(rewriter, ®ion->front(), &converter,
|
||||
*conversion);
|
||||
}
|
||||
|
||||
Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
||||
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
|
||||
}
|
||||
|
||||
Block *ConversionPatternRewriter::applySignatureConversion(
|
||||
Region *region, TypeConverter::SignatureConversion &conversion,
|
||||
Block *block, TypeConverter::SignatureConversion &conversion,
|
||||
const TypeConverter *converter) {
|
||||
assert(!impl->wasOpReplaced(region->getParentOp()) &&
|
||||
assert(!impl->wasOpReplaced(block->getParentOp()) &&
|
||||
"attempting to apply a signature conversion to a block within a "
|
||||
"replaced/erased op");
|
||||
return impl->applySignatureConversion(*this, region, conversion, converter);
|
||||
return impl->applySignatureConversion(*this, block, converter, conversion);
|
||||
}
|
||||
|
||||
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
|
||||
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
|
||||
return impl->convertRegionTypes(*this, region, converter, entryConversion);
|
||||
}
|
||||
|
||||
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
|
||||
Region *region, const TypeConverter &converter,
|
||||
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
|
||||
assert(!impl->wasOpReplaced(region->getParentOp()) &&
|
||||
"attempting to apply a signature conversion to a block within a "
|
||||
"replaced/erased op");
|
||||
return impl->convertNonEntryRegionTypes(*this, region, converter,
|
||||
blockConversions);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
|
||||
Value to) {
|
||||
LLVM_DEBUG({
|
||||
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
|
||||
// If the region of the block has a type converter, try to convert the block
|
||||
// directly.
|
||||
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
|
||||
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
|
||||
std::optional<TypeConverter::SignatureConversion> conversion =
|
||||
converter->convertBlockSignature(block);
|
||||
if (!conversion) {
|
||||
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
|
||||
"block"));
|
||||
return failure();
|
||||
}
|
||||
impl.applySignatureConversion(rewriter, block, converter, *conversion);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
|
||||
if (failed(
|
||||
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
|
||||
return failure();
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&] { rewriter.applySignatureConversion(®ion, result); });
|
||||
rewriter.modifyOpInPlace(op, [&] {
|
||||
rewriter.applySignatureConversion(®ion.front(), result);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user