mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 04:14:03 +08:00
[mlir] verify against nullptr payload in transform dialect
When establishing the correspondence between transform values and payload operations or parameters, check that the latter are non-null and report errors. This was previously allowed for exotic cases of partially successfull transformations with "apply each" trait, but was dangerous. The "apply each" implementation was reworked to remove the need for this functionality, so this can now be hardned to avoid null pointer dereferences. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D141142
This commit is contained in:
@@ -832,36 +832,30 @@ applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
|
||||
SmallVector<Diagnostic> silenceableStack;
|
||||
unsigned expectedNumResults = transformOp->getNumResults();
|
||||
for (Operation *target : targets) {
|
||||
// Emplace back a placeholder for the returned new ops and params.
|
||||
// This is filled with `expectedNumResults` if the op fails to apply.
|
||||
ApplyToEachResultList placeholder;
|
||||
placeholder.reserve(expectedNumResults);
|
||||
results.push_back(std::move(placeholder));
|
||||
|
||||
auto specificOp = dyn_cast<OpTy>(target);
|
||||
if (!specificOp) {
|
||||
Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
|
||||
diag << "transform applied to the wrong op kind";
|
||||
diag.attachNote(target->getLoc()) << "when applied to this op";
|
||||
// Producing `expectedNumResults` nullptr is a silenceableFailure mode.
|
||||
// TODO: encode this implicit `expectedNumResults` nullptr ==
|
||||
// silenceableFailure with a proper trait.
|
||||
results.back().assign(expectedNumResults, nullptr);
|
||||
silenceableStack.push_back(std::move(diag));
|
||||
continue;
|
||||
}
|
||||
|
||||
ApplyToEachResultList partialResults;
|
||||
partialResults.reserve(expectedNumResults);
|
||||
Location specificOpLoc = specificOp->getLoc();
|
||||
DiagnosedSilenceableFailure res =
|
||||
transformOp.applyToOne(specificOp, results.back(), state);
|
||||
transformOp.applyToOne(specificOp, partialResults, state);
|
||||
if (res.isDefiniteFailure() ||
|
||||
failed(detail::checkApplyToOne(transformOp, specificOpLoc,
|
||||
results.back()))) {
|
||||
partialResults))) {
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
|
||||
if (res.isSilenceableFailure())
|
||||
res.takeDiagnostics(silenceableStack);
|
||||
else
|
||||
results.push_back(std::move(partialResults));
|
||||
}
|
||||
if (!silenceableStack.empty()) {
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(
|
||||
|
||||
@@ -80,6 +80,13 @@ transform::TransformState::setPayloadOps(Value value,
|
||||
assert(!value.getType().isa<TransformParamTypeInterface>() &&
|
||||
"cannot associate payload ops with a value of parameter type");
|
||||
|
||||
for (Operation *target : targets) {
|
||||
if (target)
|
||||
continue;
|
||||
return emitError(value.getLoc())
|
||||
<< "attempting to assign a null payload op to this transform value";
|
||||
}
|
||||
|
||||
auto iface = value.getType().cast<TransformHandleTypeInterface>();
|
||||
DiagnosedSilenceableFailure result =
|
||||
iface.checkPayload(value.getLoc(), targets);
|
||||
@@ -105,6 +112,13 @@ LogicalResult transform::TransformState::setParams(Value value,
|
||||
ArrayRef<Param> params) {
|
||||
assert(value != nullptr && "attempting to set params for a null value");
|
||||
|
||||
for (Attribute attr : params) {
|
||||
if (attr)
|
||||
continue;
|
||||
return emitError(value.getLoc())
|
||||
<< "attempting to assign a null parameter to this transform value";
|
||||
}
|
||||
|
||||
auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
|
||||
assert(value &&
|
||||
"cannot associate parameter with a value of non-parameter type");
|
||||
|
||||
@@ -1024,3 +1024,19 @@ transform.sequence failures(propagate) {
|
||||
{ second_result_is_handle }
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
// expected-error @below {{attempting to assign a null payload op to this transform value}}
|
||||
%0 = transform.test_produce_null_payload : !transform.any_op
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
// expected-error @below {{attempting to assign a null parameter to this transform value}}
|
||||
%0 = transform.test_produce_null_param : !transform.param<i64>
|
||||
}
|
||||
|
||||
@@ -458,6 +458,28 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void mlir::test::TestProduceNullPayloadOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::producesHandle(getOut(), effects);
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
|
||||
transform::TransformResults &results, transform::TransformState &state) {
|
||||
SmallVector<Operation *, 1> null({nullptr});
|
||||
results.set(getOut().cast<OpResult>(), null);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void mlir::test::TestProduceNullParamOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
results.setParams(getOut().cast<OpResult>(), Attribute());
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Test extension of the Transform dialect. Registers additional ops and
|
||||
/// declares PDL as dependent dialect since the additional ops are using PDL
|
||||
|
||||
@@ -334,4 +334,22 @@ def TestProduceTransformParamOrForwardOperandOp
|
||||
}];
|
||||
}
|
||||
|
||||
def TestProduceNullPayloadOp
|
||||
: Op<Transform_Dialect, "test_produce_null_payload",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let results = (outs TransformHandleTypeInterface:$out);
|
||||
let assemblyFormat = "attr-dict `:` type($out)";
|
||||
let cppNamespace = "::mlir::test";
|
||||
}
|
||||
|
||||
def TestProduceNullParamOp
|
||||
: Op<Transform_Dialect, "test_produce_null_param",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let results = (outs TransformParamTypeInterface:$out);
|
||||
let assemblyFormat = "attr-dict `:` type($out)";
|
||||
let cppNamespace = "::mlir::test";
|
||||
}
|
||||
|
||||
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
|
||||
|
||||
Reference in New Issue
Block a user