mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[MLIR][Transform][SMT] Allow for declarative computations in schedules (#160895)
By allowing `transform.smt.constrain_params`'s region to yield SMT-vars, op instances can declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to declare that computations on params should be performed. The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region.
This commit is contained in:
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
|
||||
Pure,
|
||||
Terminator,
|
||||
ReturnLike,
|
||||
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
|
||||
"smt::ForallOp", "smt::ExistsOp"]>,
|
||||
]> {
|
||||
let summary = "terminator operation for various regions of SMT operations";
|
||||
let arguments = (ins Variadic<AnyType>:$values);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
NoTerminator
|
||||
SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
|
||||
]> {
|
||||
let cppNamespace = [{ mlir::transform::smt }];
|
||||
|
||||
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
|
||||
let description = [{
|
||||
Allows expressing constraints on params using the SMT dialect.
|
||||
|
||||
Each Transform dialect param provided as an operand has a corresponding
|
||||
Each Transform-dialect param provided as an operand has a corresponding
|
||||
argument of SMT-type in the region. The SMT-Dialect ops in the region use
|
||||
these arguments as operands.
|
||||
these params-as-SMT-vars as operands, thereby expressing relevant
|
||||
constraints on their allowed values.
|
||||
|
||||
Computations w.r.t. passed-in params can also be expressed through the
|
||||
region's SMT-ops. Namely, the constraints express relationships to other
|
||||
SMT-variables which can then be yielded from the region (with `smt.yield`).
|
||||
|
||||
The semantics of this op is that all the ops in the region together express
|
||||
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
|
||||
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
|
||||
op succeeds.
|
||||
op succeeds and any one satisfying assignment is used to map the
|
||||
SMT-variables yielded in the region to `transform.param`s.
|
||||
|
||||
---
|
||||
|
||||
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
|
||||
let results = (outs Variadic<TransformParamTypeInterface>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let assemblyFormat =
|
||||
"`(` $params `)` attr-dict `:` type(operands) $body";
|
||||
"`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
|
||||
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
||||
#include "mlir/Dialect/SMT/IR/SMTOps.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -23,6 +23,7 @@ using namespace mlir;
|
||||
void transform::smt::ConstrainParamsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getParamsMutable(), effects);
|
||||
producesHandle(getResults(), effects);
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
@@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
|
||||
// and allow for users to attach their own implementation, which would,
|
||||
// e.g., translate the ops to SMTLIB and hand that over to the user's
|
||||
// favourite solver. This requires changes to the dialect's verifier.
|
||||
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
|
||||
return emitSilenceableFailure(getLoc())
|
||||
<< "op does not have interpreted semantics yet";
|
||||
}
|
||||
|
||||
LogicalResult transform::smt::ConstrainParamsOp::verify() {
|
||||
auto yieldTerminator =
|
||||
dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
|
||||
if (!yieldTerminator)
|
||||
return emitOpError() << "expected '"
|
||||
<< mlir::smt::YieldOp::getOperationName()
|
||||
<< "' as terminator";
|
||||
|
||||
auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
|
||||
Type paramType, StringRef paramDesc,
|
||||
auto *atOp) -> InFlightDiagnostic {
|
||||
if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
|
||||
smtType))
|
||||
return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
|
||||
<< " is expected to be either a !smt.bool, a "
|
||||
"!smt.int, or a !smt.bv";
|
||||
|
||||
assert(isa<TransformParamTypeInterface>(paramType) &&
|
||||
"ODS specifies params' type should implement param interface");
|
||||
if (isa<transform::AnyParamType>(paramType))
|
||||
return {}; // No further checks can be done.
|
||||
|
||||
// NB: This cast must succeed as long as the only implementors of
|
||||
// TransformParamTypeInterface are AnyParamType and ParamType.
|
||||
Type typeWrappedByParam = cast<ParamType>(paramType).getType();
|
||||
|
||||
if (isa<mlir::smt::IntType>(smtType)) {
|
||||
if (!isa<IntegerType>(typeWrappedByParam))
|
||||
return atOp->emitOpError()
|
||||
<< "the type of " << smtDesc << " #" << idx
|
||||
<< " is !smt.int though the corresponding " << paramDesc
|
||||
<< " type (" << paramType << ") is not wrapping an integer type";
|
||||
} else if (isa<mlir::smt::BoolType>(smtType)) {
|
||||
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
|
||||
if (!wrappedIntType || wrappedIntType.getWidth() != 1)
|
||||
return atOp->emitOpError()
|
||||
<< "the type of " << smtDesc << " #" << idx
|
||||
<< " is !smt.bool though the corresponding " << paramDesc
|
||||
<< " type (" << paramType << ") is not wrapping i1";
|
||||
} else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
|
||||
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
|
||||
if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
|
||||
return atOp->emitOpError()
|
||||
<< "the type of " << smtDesc << " #" << idx << " is " << smtType
|
||||
<< " though the corresponding " << paramDesc << " type ("
|
||||
<< paramType
|
||||
<< ") is not wrapping an integer type of the same bitwidth";
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
||||
|
||||
if (getOperands().size() != getBody().getNumArguments())
|
||||
return emitOpError(
|
||||
"must have the same number of block arguments as operands");
|
||||
|
||||
for (auto [idx, operandType, blockArgType] :
|
||||
llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
|
||||
InFlightDiagnostic typeCheckResult =
|
||||
checkTypes(idx, blockArgType, "block arg", operandType, "operand",
|
||||
/*atOp=*/this);
|
||||
if (LogicalResult(typeCheckResult).failed())
|
||||
return typeCheckResult;
|
||||
}
|
||||
|
||||
for (auto &op : getBody().getOps()) {
|
||||
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
|
||||
return emitOpError(
|
||||
"ops contained in region should belong to SMT-dialect");
|
||||
}
|
||||
|
||||
if (yieldTerminator->getNumOperands() != getNumResults())
|
||||
return yieldTerminator.emitOpError()
|
||||
<< "expected terminator to have as many operands as the parent op "
|
||||
"has results";
|
||||
|
||||
for (auto [idx, termOperandType, resultType] : llvm::enumerate(
|
||||
yieldTerminator->getOperands().getType(), getResultTypes())) {
|
||||
InFlightDiagnostic typeCheckResult =
|
||||
checkTypes(idx, termOperandType, "terminator operand",
|
||||
cast<transform::ParamType>(resultType), "result",
|
||||
/*atOp=*/&yieldTerminator);
|
||||
if (LogicalResult(typeCheckResult).failed())
|
||||
return typeCheckResult;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ except ImportError as e:
|
||||
class ConstrainParamsOp(ConstrainParamsOp):
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
params: Sequence[transform.AnyParamType],
|
||||
arg_types: Sequence[Type],
|
||||
loc=None,
|
||||
@@ -27,6 +28,7 @@ class ConstrainParamsOp(ConstrainParamsOp):
|
||||
if len(params) != len(arg_types):
|
||||
raise ValueError(f"{params=} not same length as {arg_types=}")
|
||||
super().__init__(
|
||||
results,
|
||||
params,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
@@ -36,3 +38,13 @@ class ConstrainParamsOp(ConstrainParamsOp):
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
def constrain_params(
|
||||
results: Sequence[Type],
|
||||
params: Sequence[transform.AnyParamType],
|
||||
arg_types: Sequence[Type],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
|
||||
|
||||
// CHECK-LABEL: @constraint_not_using_smt_ops
|
||||
// CHECK-LABEL: @incorrect terminator
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
|
||||
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
// expected-error@below {{op expected 'smt.yield' as terminator}}
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
|
||||
^bb0(%param_as_smt_var: !smt.int):
|
||||
%c4 = arith.constant 4 : i32
|
||||
// This is the kind of thing one might think works:
|
||||
//arith.remsi %param_as_smt_var, %c4 : i32
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
@@ -22,9 +20,117 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{must have the same number of block arguments as operands}}
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
|
||||
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @constraint_not_using_smt_ops
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
|
||||
^bb0(%param_as_smt_var: !smt.int):
|
||||
%c4 = arith.constant 4 : i32
|
||||
// This is the kind of thing one might think works:
|
||||
//arith.remsi %param_as_smt_var, %c4 : i32
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @results_not_one_to_one_with_vars
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
|
||||
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
|
||||
// expected-error@below {{expected terminator to have as many operands as the parent op has results}}
|
||||
smt.yield %param_as_smt_var : !smt.int
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @non_smt_type_block_args
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
|
||||
// expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
|
||||
^bb0(%param_as_smt_var: !transform.param<i8>):
|
||||
smt.yield %param_as_smt_var : !transform.param<i8>
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mismatched_arg_type_bool
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
|
||||
^bb0(%param_as_smt_var: !smt.bool):
|
||||
smt.yield %param_as_smt_var : !smt.bool
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mismatched_arg_type_bitvector
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
|
||||
^bb0(%param_as_smt_var: !smt.bv<8>):
|
||||
smt.yield %param_as_smt_var : !smt.bv<8>
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mismatched_result_type_bool
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 1 -> !transform.param<i1>
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
|
||||
^bb0(%param_as_smt_var: !smt.bool):
|
||||
// expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
|
||||
smt.yield %param_as_smt_var : !smt.bool
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mismatched_result_type_bitvector
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
|
||||
^bb0(%param_as_smt_var: !smt.bv<8>):
|
||||
// expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
|
||||
smt.yield %param_as_smt_var : !smt.bv<8>
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
|
||||
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
|
||||
^bb0(%param_as_smt_var: !smt.int):
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
|
||||
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
|
||||
transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
|
||||
%param_a = transform.param.constant 4 -> !transform.param<i64>
|
||||
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
|
||||
%param_b = transform.param.constant 16 -> !transform.param<i64>
|
||||
%param_b = transform.param.constant 32 -> !transform.param<i64>
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
|
||||
transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
|
||||
%divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
|
||||
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
|
||||
^bb0(%var_a: !smt.int, %var_b: !smt.int):
|
||||
// CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
|
||||
%divisor = smt.int.div %var_b, %var_a
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
%c0 = smt.int.constant 0
|
||||
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
|
||||
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
|
||||
%eq = smt.eq %remainder, %c0 : !smt.int
|
||||
// CHECK: smt.assert %[[EQ]]
|
||||
smt.assert %eq
|
||||
// CHECK: smt.yield %[[DIV]]
|
||||
smt.yield %divisor : !smt.int
|
||||
}
|
||||
// NB: from here can rely on that %param_a is a divisor of %param_b
|
||||
// NB: from here can rely on that %param_a is a divisor of %param_b and
|
||||
// that the relevant factor, 8, got associated to %divisor.
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
|
||||
%param_as_param = transform.param.constant true -> !transform.any_param
|
||||
%param_as_param = transform.param.constant true -> !transform.param<i1>
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.any_param {
|
||||
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
|
||||
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
|
||||
^bb0(%param_as_smt_var: !smt.bool):
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
|
||||
@@ -25,26 +25,44 @@ def run(f):
|
||||
# CHECK-LABEL: TEST: testConstrainParamsOp
|
||||
@run
|
||||
def testConstrainParamsOp(target):
|
||||
dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
|
||||
c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
|
||||
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
|
||||
symbolic_value = transform.ParamConstantOp(
|
||||
transform.AnyParamType.get(), dummy_value
|
||||
symbolic_value_as_param = transform.ParamConstantOp(
|
||||
transform.AnyParamType.get(), c42_attr
|
||||
)
|
||||
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
constrain_params = transform_smt.ConstrainParamsOp(
|
||||
[symbolic_value], [smt.IntType.get()]
|
||||
[], [symbolic_value_as_param], [smt.IntType.get()]
|
||||
)
|
||||
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
|
||||
with ir.InsertionPoint(constrain_params.body):
|
||||
symbolic_value_as_smt_var = constrain_params.body.arguments[0]
|
||||
# CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
|
||||
# CHECK: %[[C43:.*]] = smt.int.constant 43
|
||||
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
|
||||
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
|
||||
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
|
||||
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var)
|
||||
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
|
||||
ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
|
||||
ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43)
|
||||
# CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
|
||||
bounded = smt.AndOp([lb, ub])
|
||||
# CHECK: smt.assert %[[BOUNDED:.*]]
|
||||
smt.AssertOp(bounded)
|
||||
smt.YieldOp([])
|
||||
|
||||
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
compute_with_params = transform_smt.ConstrainParamsOp(
|
||||
[transform.ParamType.get(ir.IntegerType.get_signless(32))],
|
||||
[symbolic_value_as_param],
|
||||
[smt.IntType.get()],
|
||||
)
|
||||
# CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int):
|
||||
with ir.InsertionPoint(compute_with_params.body):
|
||||
symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
|
||||
# CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]]
|
||||
twice_symb = smt.IntAddOp(
|
||||
[symbolic_value_as_smt_var, symbolic_value_as_smt_var]
|
||||
)
|
||||
# CHECK: smt.yield %[[TWICE]]
|
||||
smt.YieldOp([twice_symb])
|
||||
|
||||
Reference in New Issue
Block a user