[mlir] verify against nullptr payload in transform dialect

When establishing the correspondence between transform values and
payload operations or parameters, check that the latter are non-null and
report errors. This was previously allowed for exotic cases of partially
successfull transformations with "apply each" trait, but was dangerous.
The "apply each" implementation was reworked to remove the need for this
functionality, so this can now be hardned to avoid null pointer
dereferences.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141142
This commit is contained in:
Alex Zinenko
2023-01-09 14:01:25 +01:00
parent 499bf67208
commit 984c2c8cb3
5 changed files with 76 additions and 12 deletions

View File

@@ -832,36 +832,30 @@ applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
SmallVector<Diagnostic> silenceableStack;
unsigned expectedNumResults = transformOp->getNumResults();
for (Operation *target : targets) {
// Emplace back a placeholder for the returned new ops and params.
// This is filled with `expectedNumResults` if the op fails to apply.
ApplyToEachResultList placeholder;
placeholder.reserve(expectedNumResults);
results.push_back(std::move(placeholder));
auto specificOp = dyn_cast<OpTy>(target);
if (!specificOp) {
Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
diag << "transform applied to the wrong op kind";
diag.attachNote(target->getLoc()) << "when applied to this op";
// Producing `expectedNumResults` nullptr is a silenceableFailure mode.
// TODO: encode this implicit `expectedNumResults` nullptr ==
// silenceableFailure with a proper trait.
results.back().assign(expectedNumResults, nullptr);
silenceableStack.push_back(std::move(diag));
continue;
}
ApplyToEachResultList partialResults;
partialResults.reserve(expectedNumResults);
Location specificOpLoc = specificOp->getLoc();
DiagnosedSilenceableFailure res =
transformOp.applyToOne(specificOp, results.back(), state);
transformOp.applyToOne(specificOp, partialResults, state);
if (res.isDefiniteFailure() ||
failed(detail::checkApplyToOne(transformOp, specificOpLoc,
results.back()))) {
partialResults))) {
return DiagnosedSilenceableFailure::definiteFailure();
}
if (res.isSilenceableFailure())
res.takeDiagnostics(silenceableStack);
else
results.push_back(std::move(partialResults));
}
if (!silenceableStack.empty()) {
return DiagnosedSilenceableFailure::silenceableFailure(

View File

@@ -80,6 +80,13 @@ transform::TransformState::setPayloadOps(Value value,
assert(!value.getType().isa<TransformParamTypeInterface>() &&
"cannot associate payload ops with a value of parameter type");
for (Operation *target : targets) {
if (target)
continue;
return emitError(value.getLoc())
<< "attempting to assign a null payload op to this transform value";
}
auto iface = value.getType().cast<TransformHandleTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
@@ -105,6 +112,13 @@ LogicalResult transform::TransformState::setParams(Value value,
ArrayRef<Param> params) {
assert(value != nullptr && "attempting to set params for a null value");
for (Attribute attr : params) {
if (attr)
continue;
return emitError(value.getLoc())
<< "attempting to assign a null parameter to this transform value";
}
auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
assert(value &&
"cannot associate parameter with a value of non-parameter type");

View File

@@ -1024,3 +1024,19 @@ transform.sequence failures(propagate) {
{ second_result_is_handle }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{attempting to assign a null payload op to this transform value}}
%0 = transform.test_produce_null_payload : !transform.any_op
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{attempting to assign a null parameter to this transform value}}
%0 = transform.test_produce_null_param : !transform.param<i64>
}

View File

@@ -458,6 +458,28 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
void mlir::test::TestProduceNullPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
SmallVector<Operation *, 1> null({nullptr});
results.set(getOut().cast<OpResult>(), null);
return DiagnosedSilenceableFailure::success();
}
void mlir::test::TestProduceNullParamOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
DiagnosedSilenceableFailure
mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
results.setParams(getOut().cast<OpResult>(), Attribute());
return DiagnosedSilenceableFailure::success();
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL

View File

@@ -334,4 +334,22 @@ def TestProduceTransformParamOrForwardOperandOp
}];
}
def TestProduceNullPayloadOp
: Op<Transform_Dialect, "test_produce_null_payload",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let results = (outs TransformHandleTypeInterface:$out);
let assemblyFormat = "attr-dict `:` type($out)";
let cppNamespace = "::mlir::test";
}
def TestProduceNullParamOp
: Op<Transform_Dialect, "test_produce_null_param",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let results = (outs TransformParamTypeInterface:$out);
let assemblyFormat = "attr-dict `:` type($out)";
let cppNamespace = "::mlir::test";
}
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD