mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[MLIR][Transform][SMT] Introduce transform.smt.constrain_params (#159450)
Introduces a Transform-dialect SMT-extension so that we can have an op to express constrains on Transform-dialect params, in particular when these params are knobs -- see transform.tune.knob -- and can hence be seen as symbolic variables. This op allows expressing joint constraints over multiple params/knobs together. While the op's semantics are clearly defined, per SMTLIB, the interpreted semantics -- i.e. the `apply()` method -- for now just defaults to failure. In the future we should support attaching an implementation so that users can Bring Your Own Solver and thereby control performance of interpreting the op. For now the main usage is to walk schedule IR and collect these constraints so that knobs can be rewritten to constants that satisfy the constraints.
This commit is contained in:
@@ -4,5 +4,6 @@ add_subdirectory(IR)
|
||||
add_subdirectory(IRDLExtension)
|
||||
add_subdirectory(LoopExtension)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(SMTExtension)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(TuneExtension)
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
|
||||
mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)
|
||||
|
||||
add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
|
||||
@@ -0,0 +1,27 @@
|
||||
//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
|
||||
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace transform {
|
||||
/// Registers the SMT extension of the Transform dialect in the given registry.
|
||||
void registerSMTExtension(DialectRegistry &dialectRegistry);
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
|
||||
@@ -0,0 +1,21 @@
|
||||
//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
|
||||
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
|
||||
@@ -0,0 +1,52 @@
|
||||
//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
|
||||
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
|
||||
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
NoTerminator
|
||||
]> {
|
||||
let cppNamespace = [{ mlir::transform::smt }];
|
||||
|
||||
let summary = "Express contraints on params interpreted as symbolic values";
|
||||
let description = [{
|
||||
Allows expressing constraints on params using the SMT dialect.
|
||||
|
||||
Each Transform dialect param provided as an operand has a corresponding
|
||||
argument of SMT-type in the region. The SMT-Dialect ops in the region use
|
||||
these arguments as operands.
|
||||
|
||||
The semantics of this op is that all the ops in the region together express
|
||||
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
|
||||
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
|
||||
op succeeds.
|
||||
|
||||
---
|
||||
|
||||
TODO: currently the operational semantics per the Transform interpreter is
|
||||
to always fail. The intention is build out support for hooking in your own
|
||||
operational semantics so you can invoke your favourite solver to determine
|
||||
satisfiability of the corresponding constraint problem.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let assemblyFormat =
|
||||
"`(` $params `)` attr-dict `:` type(operands) $body";
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
|
||||
@@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors;
|
||||
|
||||
static void populateDialectSMTSubmodule(nanobind::module_ &m) {
|
||||
|
||||
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](const nb::object &, MlirContext context) {
|
||||
return mlirSMTTypeGetBool(context);
|
||||
},
|
||||
"cls"_a, "context"_a = nb::none());
|
||||
auto smtBoolType =
|
||||
mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
|
||||
.def_staticmethod(
|
||||
"get",
|
||||
[](MlirContext context) { return mlirSMTTypeGetBool(context); },
|
||||
"context"_a = nb::none());
|
||||
auto smtBitVectorType =
|
||||
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
|
||||
.def_classmethod(
|
||||
.def_staticmethod(
|
||||
"get",
|
||||
[](const nb::object &, int32_t width, MlirContext context) {
|
||||
[](int32_t width, MlirContext context) {
|
||||
return mlirSMTTypeGetBitVector(context, width);
|
||||
},
|
||||
"cls"_a, "width"_a, "context"_a = nb::none());
|
||||
"width"_a, "context"_a = nb::none());
|
||||
auto smtIntType =
|
||||
mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
|
||||
.def_staticmethod(
|
||||
"get",
|
||||
[](MlirContext context) { return mlirSMTTypeGetInt(context); },
|
||||
"context"_a = nb::none());
|
||||
|
||||
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
|
||||
bool indentLetBody) {
|
||||
|
||||
@@ -4,6 +4,7 @@ add_subdirectory(IR)
|
||||
add_subdirectory(IRDLExtension)
|
||||
add_subdirectory(LoopExtension)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(SMTExtension)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(TuneExtension)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
12
mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
Normal file
12
mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_mlir_dialect_library(MLIRTransformSMTExtension
|
||||
SMTExtension.cpp
|
||||
SMTExtensionOps.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRTransformDialectSMTExtensionOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransformDialect
|
||||
MLIRSMT
|
||||
)
|
||||
35
mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
Normal file
35
mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transform op registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)
|
||||
|
||||
SMTExtension() {
|
||||
registerTransformOps<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
|
||||
dialectRegistry.addExtensions<SMTExtension>();
|
||||
}
|
||||
55
mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
Normal file
55
mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
|
||||
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstrainParamsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void transform::smt::ConstrainParamsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getParamsMutable(), effects);
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
// TODO: Proper operational semantics are to check the SMT problem in the body
|
||||
// with a SMT solver with the arguments of the body constrained to the
|
||||
// values passed into the op. Success or failure is then determined by
|
||||
// the solver's result.
|
||||
// One way to support this is to just promise the TransformOpInterface
|
||||
// and allow for users to attach their own implementation, which would,
|
||||
// e.g., translate the ops to SMTLIB and hand that over to the user's
|
||||
// favourite solver. This requires changes to the dialect's verifier.
|
||||
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
|
||||
}
|
||||
|
||||
LogicalResult transform::smt::ConstrainParamsOp::verify() {
|
||||
if (getOperands().size() != getBody().getNumArguments())
|
||||
return emitOpError(
|
||||
"must have the same number of block arguments as operands");
|
||||
|
||||
for (auto &op : getBody().getOps()) {
|
||||
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
|
||||
return emitOpError(
|
||||
"ops contained in region should belong to SMT-dialect");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -53,6 +53,7 @@
|
||||
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
|
||||
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
|
||||
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
|
||||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
|
||||
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
|
||||
@@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) {
|
||||
transform::registerIRDLExtension(registry);
|
||||
transform::registerLoopExtension(registry);
|
||||
transform::registerPDLExtension(registry);
|
||||
transform::registerSMTExtension(registry);
|
||||
transform::registerTuneExtension(registry);
|
||||
vector::registerTransformDialectExtension(registry);
|
||||
arm_neon::registerTransformDialectExtension(registry);
|
||||
|
||||
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_pdl_extension)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformSMTExtensionOps.td
|
||||
SOURCES
|
||||
dialects/transform/smt.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_smt_extension)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
||||
19
mlir/python/mlir/dialects/TransformSMTExtensionOps.td
Normal file
19
mlir/python/mlir/dialects/TransformSMTExtensionOps.td
Normal file
@@ -0,0 +1,19 @@
|
||||
//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Entry point of the generated Python bindings for the SMT extension of the
|
||||
// Transform dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
|
||||
#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
|
||||
|
||||
include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
|
||||
@@ -3,6 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._smt_ops_gen import *
|
||||
from ._smt_enum_gen import *
|
||||
|
||||
from .._mlir_libs._mlirDialectsSMT import *
|
||||
from ..extras.meta import region_op
|
||||
|
||||
38
mlir/python/mlir/dialects/transform/smt.py
Normal file
38
mlir/python/mlir/dialects/transform/smt.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from ...ir import Type, Block
|
||||
from .._transform_smt_extension_ops_gen import *
|
||||
from .._transform_smt_extension_ops_gen import _Dialect
|
||||
from ...dialects import transform
|
||||
|
||||
try:
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstrainParamsOp(ConstrainParamsOp):
|
||||
def __init__(
|
||||
self,
|
||||
params: Sequence[transform.AnyParamType],
|
||||
arg_types: Sequence[Type],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if len(params) != len(arg_types):
|
||||
raise ValueError(f"{params=} not same length as {arg_types=}")
|
||||
super().__init__(
|
||||
params,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*arg_types)
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
30
mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
Normal file
30
mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
Normal file
@@ -0,0 +1,30 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
|
||||
|
||||
// CHECK-LABEL: @constraint_not_using_smt_ops
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
^bb0(%param_as_smt_var: !smt.int):
|
||||
%c4 = arith.constant 4 : i32
|
||||
// This is the kind of thing one might think works:
|
||||
//arith.remsi %param_as_smt_var, %c4 : i32
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @operands_not_one_to_one_with_vars
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
// expected-error@below {{must have the same number of block arguments as operands}}
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
87
mlir/test/Dialect/Transform/test-smt-extension.mlir
Normal file
87
mlir/test/Dialect/Transform/test-smt-extension.mlir
Normal file
@@ -0,0 +1,87 @@
|
||||
// RUN: mlir-opt %s --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @schedule_with_constrained_param
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @schedule_with_constrained_param(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
|
||||
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
|
||||
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
|
||||
^bb0(%param_as_smt_var: !smt.int):
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
%c0 = smt.int.constant 0
|
||||
// CHECK: %[[C43:.*]] = smt.int.constant 43
|
||||
%c43 = smt.int.constant 43
|
||||
// CHECK: %[[LOWER_BOUND:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
|
||||
%lower_bound = smt.int.cmp le %c0, %param_as_smt_var
|
||||
// CHECK: smt.assert %[[LOWER_BOUND]]
|
||||
smt.assert %lower_bound
|
||||
// CHECK: %[[UPPER_BOUND:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
|
||||
%upper_bound = smt.int.cmp le %param_as_smt_var, %c43
|
||||
// CHECK: smt.assert %[[UPPER_BOUND]]
|
||||
smt.assert %upper_bound
|
||||
}
|
||||
// NB: from here can rely on that 0 <= %param_as_param <= 43, even if its
|
||||
// definition changes.
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
|
||||
%param_a = transform.param.constant 4 -> !transform.param<i64>
|
||||
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
|
||||
%param_b = transform.param.constant 16 -> !transform.param<i64>
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
|
||||
transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
|
||||
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
|
||||
^bb0(%var_a: !smt.int, %var_b: !smt.int):
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
%c0 = smt.int.constant 0
|
||||
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
|
||||
%remainder = smt.int.mod %var_b, %var_a
|
||||
// CHECK: %[[EQ:.*]] = smt.eq %[[REMAINDER]], %[[C0]]
|
||||
%eq = smt.eq %remainder, %c0 : !smt.int
|
||||
// CHECK: smt.assert %[[EQ]]
|
||||
smt.assert %eq
|
||||
}
|
||||
// NB: from here can rely on that %param_a is a divisor of %param_b
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @schedule_with_param_as_a_bool
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
|
||||
%param_as_param = transform.param.constant true -> !transform.any_param
|
||||
|
||||
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
transform.smt.constrain_params(%param_as_param) : !transform.any_param {
|
||||
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
|
||||
^bb0(%param_as_smt_var: !smt.bool):
|
||||
// CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
%c0 = smt.int.constant 0
|
||||
// CHECK: %[[C1:.*]] = smt.int.constant 1
|
||||
%c1 = smt.int.constant 1
|
||||
// CHECK: %[[FALSEHOOD:.*]] = smt.eq %[[C0]], %[[C1]]
|
||||
%falsehood = smt.eq %c0, %c1 : !smt.int
|
||||
// CHECK: %[[TRUE_IFF_PARAM_IS:.*]] = smt.or %[[PARAM_AS_SMT_VAR]], %[[FALSEHOOD]]
|
||||
%true_iff_param_is = smt.or %param_as_smt_var, %falsehood
|
||||
// CHECK: smt.assert %[[TRUE_IFF_PARAM_IS]]
|
||||
smt.assert %true_iff_param_is
|
||||
}
|
||||
// NB: from here can rely on that %param_as_param holds true, even if its
|
||||
// definition changes.
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
50
mlir/test/python/dialects/transform_smt_ext.py
Normal file
50
mlir/test/python/dialects/transform_smt_ext.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir import ir
|
||||
from mlir.dialects import transform, smt
|
||||
from mlir.dialects.transform import smt as transform_smt
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
with ir.InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with ir.InsertionPoint(sequence.body):
|
||||
f(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConstrainParamsOp
|
||||
@run
|
||||
def testConstrainParamsOp(target):
|
||||
dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
|
||||
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
|
||||
symbolic_value = transform.ParamConstantOp(
|
||||
transform.AnyParamType.get(), dummy_value
|
||||
)
|
||||
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
|
||||
constrain_params = transform_smt.ConstrainParamsOp(
|
||||
[symbolic_value], [smt.IntType.get()]
|
||||
)
|
||||
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
|
||||
with ir.InsertionPoint(constrain_params.body):
|
||||
# CHECK: %[[C0:.*]] = smt.int.constant 0
|
||||
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
|
||||
# CHECK: %[[C43:.*]] = smt.int.constant 43
|
||||
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
|
||||
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
|
||||
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
|
||||
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
|
||||
ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
|
||||
# CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
|
||||
bounded = smt.AndOp([lb, ub])
|
||||
# CHECK: smt.assert %[[BOUNDED:.*]]
|
||||
smt.AssertOp(bounded)
|
||||
Reference in New Issue
Block a user