[MLIR] Modify Partial op conversion mode to optionally track all non-legalizable operations.

There are three op conversion modes: Partial, Full, and Analysis. This change modifies the Partial mode to optionally take a set of non-legalizable ops. If this parameter is specified, all ops that are not legalizable (i.e. would cause full conversion to fail) are tracked throughout the partial legalization.

Differential Revision: https://reviews.llvm.org/D78788
This commit is contained in:
Lucy Fox
2020-04-30 09:47:19 -07:00
parent 9fc0e7c1aa
commit 8de482ea9a
4 changed files with 72 additions and 26 deletions

View File

@@ -660,20 +660,25 @@ private:
/// ConversionPatternRewriter, to see what additional constraints are imposed on
/// the use of the PatternRewriter.
/// Apply a partial conversion on the given operations, and all nested
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are unreachable blocks in any of the regions nested
/// within 'ops'. If 'converter' is provided, the signatures of blocks and
/// regions are also converted.
/// returns failure if there ops explicitly marked as illegal. If `converter` is
/// provided, the signatures of blocks and regions are also converted.
/// If an `unconvertedOps` set is provided, all operations that are found not
/// to be legalizable to the given `target` are placed within that set. (Note
/// that if there is an op explicitly marked as illegal, the conversion
/// terminates and the `unconvertedOps` set will not necessarily be complete.)
LLVM_NODISCARD LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter = nullptr);
TypeConverter *converter = nullptr,
DenseSet<Operation *> *unconvertedOps = nullptr);
LLVM_NODISCARD LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter = nullptr);
TypeConverter *converter = nullptr,
DenseSet<Operation *> *unconvertedOps = nullptr);
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation

View File

@@ -1541,9 +1541,8 @@ struct OperationConverter {
explicit OperationConverter(ConversionTarget &target,
const OwningRewritePatternList &patterns,
OpConversionMode mode,
DenseSet<Operation *> *legalizableOps = nullptr)
: opLegalizer(target, patterns), mode(mode),
legalizableOps(legalizableOps) {}
DenseSet<Operation *> *trackedOps = nullptr)
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops,
@@ -1563,9 +1562,11 @@ private:
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
/// A set of pre-existing operations that were found to be legalizable to the
/// target. This field is only used when mode == OpConversionMode::Analysis.
DenseSet<Operation *> *legalizableOps;
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
/// this is populated with ops found to be legalizable to the target.
/// When mode == OpConversionMode::Partial, this is populated with ops found
/// *not* to be legalizable to the target.
DenseSet<Operation *> *trackedOps;
};
} // end anonymous namespace
@@ -1594,17 +1595,22 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
/// Partial conversions allow conversions to fail iff the operation was not
/// explicitly marked as illegal.
if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
/// explicitly marked as illegal. If the user provided a nonlegalizableOps
/// set, non-legalizable ops are included.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
if (trackedOps)
trackedOps->insert(op);
}
} else {
/// Analysis conversions don't fail if any operations fail to legalize,
/// they are only interested in the operations that were successfully
/// legalized.
if (mode == OpConversionMode::Analysis)
legalizableOps->insert(op);
trackedOps->insert(op);
// If legalization succeeded, convert the types any of the blocks within
// this operation.
@@ -1932,21 +1938,30 @@ auto ConversionTarget::getOpInfo(OperationName op) const
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Apply a partial conversion on the given operations, and all nested
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize.
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there ops explicitly marked as illegal. If `converter` is
/// provided, the signatures of blocks and regions are also converted.
/// If an `unconvertedOps` set is provided, all operations that are found not
/// to be legalizable to the given `target` are placed within that set. (Note
/// that if there is an op explicitly marked as illegal, the conversion
/// terminates and the `unconvertedOps` set will not necessarily be complete.)
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const OwningRewritePatternList &patterns, TypeConverter *converter) {
OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
const OwningRewritePatternList &patterns, TypeConverter *converter,
DenseSet<Operation *> *unconvertedOps) {
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
unconvertedOps);
return opConverter.convertOperations(ops, converter);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
const OwningRewritePatternList &patterns,
TypeConverter *converter) {
TypeConverter *converter,
DenseSet<Operation *> *unconvertedOps) {
return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
converter);
converter, unconvertedOps);
}
/// Apply a complete conversion on the given operations, and all nested

View File

@@ -4,6 +4,7 @@
func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
%result = "test.illegal_op_a"() : () -> (i32)
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return %result : i32
}
@@ -11,6 +12,7 @@ func @verifyDirectPattern() -> i32 {
func @verifyLargerBenefit() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
%result = "test.illegal_op_c"() : () -> (i32)
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return %result : i32
}
@@ -26,7 +28,9 @@ func @remap_input_1_to_1(%arg0: i64) {
// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
func @remap_call_1_to_1(%arg0: i64) {
// CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
// expected-remark@+1 {{op 'std.call' is not legalizable}}
call @remap_input_1_to_1(%arg0) : (i64) -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -40,6 +44,7 @@ func @remap_input_1_to_N(%arg0: f32) -> f32 {
func @remap_input_1_to_N_remaining_use(%arg0: f32) {
// CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
// CHECK-NEXT: "work"([[CAST]]) : (f32) -> ()
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg0) : (f32) -> ()
}
@@ -47,6 +52,7 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
func @remap_input_to_self(%arg0: index) {
// CHECK-NOT: test.cast
// CHECK: "work"
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg0) : (index) -> ()
}
@@ -59,12 +65,14 @@ func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
// CHECK-LABEL: func @no_remap_nested
func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
^bb0(%i0: i64, %unused: i16, %i1: i64):
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -78,6 +86,7 @@ func @remap_moved_region_args() {
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -91,6 +100,7 @@ func @remap_cloned_region_args() {
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) {legalizer.should_clone} : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -102,6 +112,7 @@ func @remap_drop_region() {
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -109,6 +120,7 @@ func @remap_drop_region() {
func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: "test.cast"{{.*}} : () -> i16
// CHECK-NEXT: "work"{{.*}} : (i16)
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg) : (i16) -> ()
}
@@ -117,6 +129,7 @@ func @up_to_date_replacement(%arg: i8) -> i8 {
// CHECK-NEXT: return
%repl_1 = "test.rewrite"(%arg) : (i8) -> i8
%repl_2 = "test.rewrite"(%repl_1) : (i8) -> i8
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return %repl_2 : i8
}
@@ -127,11 +140,13 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
%0 = "test.op_with_region_fold"(%arg0) ({
"foo.op_with_region_terminator"() : () -> ()
}) : (i32) -> (i32)
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return %0 : i32
}
// CHECK-LABEL: @create_block
func @create_block() {
// expected-remark@+1 {{op 'test.container' is not legalizable}}
"test.container"() ({
// Check that we created a block with arguments.
// CHECK-NOT: test.create_block
@@ -140,6 +155,7 @@ func @create_block() {
"test.create_block"() : () -> ()
"test.finish"() : () -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -147,6 +163,7 @@ func @create_block() {
func @bounded_recursion() {
// CHECK: test.recursive_rewrite 0
test.recursive_rewrite 3
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -188,13 +205,16 @@ func @fail_to_convert_region() {
// CHECK-LABEL: @create_illegal_block
func @create_illegal_block() {
// expected-remark@+1 {{op 'test.container' is not legalizable}}
"test.container"() ({
// Check that we can undo block creation, i.e. that the block was removed.
// CHECK: test.create_illegal_block
// CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
// expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}}
"test.create_illegal_block"() : () -> ()
"test.finish"() : () -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}
@@ -202,6 +222,7 @@ func @create_illegal_block() {
// CHECK-LABEL: @undo_block_arg_replace
func @undo_block_arg_replace() {
// expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
"test.undo_block_arg_replace"() ({
^bb0(%arg0: i32):
// CHECK: ^bb0(%[[ARG:.*]]: i32):
@@ -209,5 +230,6 @@ func @undo_block_arg_replace() {
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'std.return' is not legalizable}}
return
}

View File

@@ -515,8 +515,12 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
(void)applyPartialConversion(getOperation(), target, patterns,
&converter);
DenseSet<Operation *> unlegalizedOps;
(void)applyPartialConversion(getOperation(), target, patterns, &converter,
&unlegalizedOps);
// Emit remarks for each legalizable operation.
for (auto *op : unlegalizedOps)
op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
return;
}