[MLIR][Transform][SMT] Allow for declarative computations in schedules (#160895)

By allowing `transform.smt.constrain_params`'s region to yield SMT-vars,
op instances can declare relationships, through constraints, on incoming
params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it
possible to declare that computations on params should be performed.

The semantics are that the yielded SMT-vars should be from any valid
satisfying assignment/model of the constraints in the region.
This commit is contained in:
Rolf Morel
2025-10-19 00:48:23 +01:00
committed by GitHub
parent 34ed1dcf61
commit 9351ad638b
8 changed files with 256 additions and 32 deletions

View File

@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
"smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);

View File

@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"

View File

@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
NoTerminator
SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
]> {
let cppNamespace = [{ mlir::transform::smt }];
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
let description = [{
Allows expressing constraints on params using the SMT dialect.
Each Transform dialect param provided as an operand has a corresponding
Each Transform-dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
these arguments as operands.
these params-as-SMT-vars as operands, thereby expressing relevant
constraints on their allowed values.
Computations w.r.t. passed-in params can also be expressed through the
region's SMT-ops. Namely, the constraints express relationships to other
SMT-variables which can then be yielded from the region (with `smt.yield`).
The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
op succeeds.
op succeeds and any one satisfying assignment is used to map the
SMT-variables yielded in the region to `transform.param`s.
---
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
}];
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
let results = (outs Variadic<TransformParamTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"`(` $params `)` attr-dict `:` type(operands) $body";
"`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
let hasVerifier = 1;
}

View File

@@ -8,8 +8,8 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
@@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
producesHandle(getResults(), effects);
}
DiagnosedSilenceableFailure
@@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
return emitSilenceableFailure(getLoc())
<< "op does not have interpreted semantics yet";
}
LogicalResult transform::smt::ConstrainParamsOp::verify() {
auto yieldTerminator =
dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
if (!yieldTerminator)
return emitOpError() << "expected '"
<< mlir::smt::YieldOp::getOperationName()
<< "' as terminator";
auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
Type paramType, StringRef paramDesc,
auto *atOp) -> InFlightDiagnostic {
if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
smtType))
return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
<< " is expected to be either a !smt.bool, a "
"!smt.int, or a !smt.bv";
assert(isa<TransformParamTypeInterface>(paramType) &&
"ODS specifies params' type should implement param interface");
if (isa<transform::AnyParamType>(paramType))
return {}; // No further checks can be done.
// NB: This cast must succeed as long as the only implementors of
// TransformParamTypeInterface are AnyParamType and ParamType.
Type typeWrappedByParam = cast<ParamType>(paramType).getType();
if (isa<mlir::smt::IntType>(smtType)) {
if (!isa<IntegerType>(typeWrappedByParam))
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx
<< " is !smt.int though the corresponding " << paramDesc
<< " type (" << paramType << ") is not wrapping an integer type";
} else if (isa<mlir::smt::BoolType>(smtType)) {
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
if (!wrappedIntType || wrappedIntType.getWidth() != 1)
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx
<< " is !smt.bool though the corresponding " << paramDesc
<< " type (" << paramType << ") is not wrapping i1";
} else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx << " is " << smtType
<< " though the corresponding " << paramDesc << " type ("
<< paramType
<< ") is not wrapping an integer type of the same bitwidth";
}
return {};
};
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
for (auto [idx, operandType, blockArgType] :
llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
InFlightDiagnostic typeCheckResult =
checkTypes(idx, blockArgType, "block arg", operandType, "operand",
/*atOp=*/this);
if (LogicalResult(typeCheckResult).failed())
return typeCheckResult;
}
for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}
if (yieldTerminator->getNumOperands() != getNumResults())
return yieldTerminator.emitOpError()
<< "expected terminator to have as many operands as the parent op "
"has results";
for (auto [idx, termOperandType, resultType] : llvm::enumerate(
yieldTerminator->getOperands().getType(), getResultTypes())) {
InFlightDiagnostic typeCheckResult =
checkTypes(idx, termOperandType, "terminator operand",
cast<transform::ParamType>(resultType), "result",
/*atOp=*/&yieldTerminator);
if (LogicalResult(typeCheckResult).failed())
return typeCheckResult;
}
return success();
}

View File

@@ -19,6 +19,7 @@ except ImportError as e:
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
self,
results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
@@ -27,6 +28,7 @@ class ConstrainParamsOp(ConstrainParamsOp):
if len(params) != len(arg_types):
raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
results,
params,
loc=loc,
ip=ip,
@@ -36,3 +38,13 @@ class ConstrainParamsOp(ConstrainParamsOp):
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
def constrain_params(
results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
ip=None,
):
return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)

View File

@@ -1,15 +1,13 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
// CHECK-LABEL: @constraint_not_using_smt_ops
// CHECK-LABEL: @incorrect terminator
module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
// expected-error@below {{op expected 'smt.yield' as terminator}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
//arith.remsi %param_as_smt_var, %c4 : i32
transform.yield
}
transform.yield
}
@@ -22,9 +20,117 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{must have the same number of block arguments as operands}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
}
transform.yield
}
}
// -----
// CHECK-LABEL: @constraint_not_using_smt_ops
module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
//arith.remsi %param_as_smt_var, %c4 : i32
}
transform.yield
}
}
// -----
// CHECK-LABEL: @results_not_one_to_one_with_vars
module attributes {transform.with_named_sequence} {
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
// expected-error@below {{expected terminator to have as many operands as the parent op has results}}
smt.yield %param_as_smt_var : !smt.int
}
transform.yield
}
}
// -----
// CHECK-LABEL: @non_smt_type_block_args
module attributes {transform.with_named_sequence} {
transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
// expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
^bb0(%param_as_smt_var: !transform.param<i8>):
smt.yield %param_as_smt_var : !transform.param<i8>
}
transform.yield
}
}
// -----
// CHECK-LABEL: @mismatched_arg_type_bool
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bool):
smt.yield %param_as_smt_var : !smt.bool
}
transform.yield
}
}
// -----
// CHECK-LABEL: @mismatched_arg_type_bitvector
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bv<8>):
smt.yield %param_as_smt_var : !smt.bv<8>
}
transform.yield
}
}
// -----
// CHECK-LABEL: @mismatched_result_type_bool
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 1 -> !transform.param<i1>
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bool):
// expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
smt.yield %param_as_smt_var : !smt.bool
}
transform.yield
}
}
// -----
// CHECK-LABEL: @mismatched_result_type_bitvector
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bv<8>):
// expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
smt.yield %param_as_smt_var : !smt.bv<8>
}
transform.yield
}
}

View File

@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
^bb0(%param_as_smt_var: !smt.int):
// CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
module attributes {transform.with_named_sequence} {
transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
%param_a = transform.param.constant 4 -> !transform.param<i64>
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
%param_b = transform.param.constant 16 -> !transform.param<i64>
%param_b = transform.param.constant 32 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
%divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
^bb0(%var_a: !smt.int, %var_b: !smt.int):
// CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
%divisor = smt.int.div %var_b, %var_a
// CHECK: %[[C0:.*]] = smt.int.constant 0
%c0 = smt.int.constant 0
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
%eq = smt.eq %remainder, %c0 : !smt.int
// CHECK: smt.assert %[[EQ]]
smt.assert %eq
// CHECK: smt.yield %[[DIV]]
smt.yield %divisor : !smt.int
}
// NB: from here can rely on that %param_a is a divisor of %param_b
// NB: from here can rely on that %param_a is a divisor of %param_b and
// that the relevant factor, 8, got associated to %divisor.
transform.yield
}
}
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
module attributes {transform.with_named_sequence} {
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
%param_as_param = transform.param.constant true -> !transform.any_param
%param_as_param = transform.param.constant true -> !transform.param<i1>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
transform.smt.constrain_params(%param_as_param) : !transform.any_param {
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
^bb0(%param_as_smt_var: !smt.bool):
// CHECK: %[[C0:.*]] = smt.int.constant 0

View File

@@ -25,26 +25,44 @@ def run(f):
# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
symbolic_value = transform.ParamConstantOp(
transform.AnyParamType.get(), dummy_value
symbolic_value_as_param = transform.ParamConstantOp(
transform.AnyParamType.get(), c42_attr
)
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
constrain_params = transform_smt.ConstrainParamsOp(
[symbolic_value], [smt.IntType.get()]
[], [symbolic_value_as_param], [smt.IntType.get()]
)
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
with ir.InsertionPoint(constrain_params.body):
symbolic_value_as_smt_var = constrain_params.body.arguments[0]
# CHECK: %[[C0:.*]] = smt.int.constant 0
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
# CHECK: %[[C43:.*]] = smt.int.constant 43
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var)
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43)
# CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
bounded = smt.AndOp([lb, ub])
# CHECK: smt.assert %[[BOUNDED:.*]]
smt.AssertOp(bounded)
smt.YieldOp([])
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
compute_with_params = transform_smt.ConstrainParamsOp(
[transform.ParamType.get(ir.IntegerType.get_signless(32))],
[symbolic_value_as_param],
[smt.IntType.get()],
)
# CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int):
with ir.InsertionPoint(compute_with_params.body):
symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
# CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]]
twice_symb = smt.IntAddOp(
[symbolic_value_as_smt_var, symbolic_value_as_smt_var]
)
# CHECK: smt.yield %[[TWICE]]
smt.YieldOp([twice_symb])