mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[mlir][python] Normalize asm-printing IR behavior.
While working on an integration, I found a lot of inconsistencies on IR printing and verification. It turns out that we were: * Only doing "soft fail" verification on IR printing of Operation, not of a Module. * Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass an `str` to a `bytes` based handle). * For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors. This patch: * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation. * Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works). * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it. It turns out that we had a number of tests which were generating illegal IR but it wasn't being caught because they were doing a print on the `Module` vs operation. All except two were trivially fixed: * linalg/ops.py : Had two tests for direct constructing a Matmul incorrectly. Fixing them made them just like the next two tests so just deleted (no need to test the verifier only at this level). * linalg/opdsl/emit_structured_generic.py : Hand coded conv and pooling tests appear to be using illegal shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a followup (would also be nice to break this file up into multiple test modules as it is hard to tell exactly what is failing). Notes to downstreams: * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`. Differential Revision: https://reviews.llvm.org/D114680
This commit is contained in:
@@ -93,6 +93,13 @@ Args:
|
||||
use_local_Scope: Whether to print in a way that is more optimized for
|
||||
multi-threaded access but may not be consistent with how the overall
|
||||
module prints.
|
||||
assume_verified: By default, if not printing generic form, the verifier
|
||||
will be run and if it fails, generic form will be printed with a comment
|
||||
about failed verification. While a reasonable default for interactive use,
|
||||
for systematic use, it is often better for the caller to verify explicitly
|
||||
and report failures in a more robust fashion. Set this to True if doing this
|
||||
in order to avoid running a redundant verification. If the IR is actually
|
||||
invalid, behavior is undefined.
|
||||
)";
|
||||
|
||||
static const char kOperationGetAsmDocstring[] =
|
||||
@@ -828,14 +835,21 @@ void PyOperation::checkValid() const {
|
||||
void PyOperationBase::print(py::object fileObject, bool binary,
|
||||
llvm::Optional<int64_t> largeElementsLimit,
|
||||
bool enableDebugInfo, bool prettyDebugInfo,
|
||||
bool printGenericOpForm, bool useLocalScope) {
|
||||
bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified) {
|
||||
PyOperation &operation = getOperation();
|
||||
operation.checkValid();
|
||||
if (fileObject.is_none())
|
||||
fileObject = py::module::import("sys").attr("stdout");
|
||||
|
||||
if (!printGenericOpForm && !mlirOperationVerify(operation)) {
|
||||
fileObject.attr("write")("// Verification failed, printing generic form\n");
|
||||
if (!assumeVerified && !printGenericOpForm &&
|
||||
!mlirOperationVerify(operation)) {
|
||||
std::string message("// Verification failed, printing generic form\n");
|
||||
if (binary) {
|
||||
fileObject.attr("write")(py::bytes(message));
|
||||
} else {
|
||||
fileObject.attr("write")(py::str(message));
|
||||
}
|
||||
printGenericOpForm = true;
|
||||
}
|
||||
|
||||
@@ -857,8 +871,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
|
||||
py::object PyOperationBase::getAsm(bool binary,
|
||||
llvm::Optional<int64_t> largeElementsLimit,
|
||||
bool enableDebugInfo, bool prettyDebugInfo,
|
||||
bool printGenericOpForm,
|
||||
bool useLocalScope) {
|
||||
bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified) {
|
||||
py::object fileObject;
|
||||
if (binary) {
|
||||
fileObject = py::module::import("io").attr("BytesIO")();
|
||||
@@ -870,7 +884,8 @@ py::object PyOperationBase::getAsm(bool binary,
|
||||
/*enableDebugInfo=*/enableDebugInfo,
|
||||
/*prettyDebugInfo=*/prettyDebugInfo,
|
||||
/*printGenericOpForm=*/printGenericOpForm,
|
||||
/*useLocalScope=*/useLocalScope);
|
||||
/*useLocalScope=*/useLocalScope,
|
||||
/*assumeVerified=*/assumeVerified);
|
||||
|
||||
return fileObject.attr("getvalue")();
|
||||
}
|
||||
@@ -2149,12 +2164,9 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
kDumpDocstring)
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyModule &self) {
|
||||
MlirOperation operation = mlirModuleGetOperation(self.get());
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirOperationPrint(operation, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
[](py::object self) {
|
||||
// Defer to the operation's __str__.
|
||||
return self.attr("operation").attr("__str__")();
|
||||
},
|
||||
kOperationStrDunderDocstring);
|
||||
|
||||
@@ -2234,7 +2246,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
/*enableDebugInfo=*/false,
|
||||
/*prettyDebugInfo=*/false,
|
||||
/*printGenericOpForm=*/false,
|
||||
/*useLocalScope=*/false);
|
||||
/*useLocalScope=*/false,
|
||||
/*assumeVerified=*/false);
|
||||
},
|
||||
"Returns the assembly form of the operation.")
|
||||
.def("print", &PyOperationBase::print,
|
||||
@@ -2244,7 +2257,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
py::arg("enable_debug_info") = false,
|
||||
py::arg("pretty_debug_info") = false,
|
||||
py::arg("print_generic_op_form") = false,
|
||||
py::arg("use_local_scope") = false, kOperationPrintDocstring)
|
||||
py::arg("use_local_scope") = false,
|
||||
py::arg("assume_verified") = false, kOperationPrintDocstring)
|
||||
.def("get_asm", &PyOperationBase::getAsm,
|
||||
// Careful: Lots of arguments must match up with get_asm method.
|
||||
py::arg("binary") = false,
|
||||
@@ -2252,7 +2266,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
py::arg("enable_debug_info") = false,
|
||||
py::arg("pretty_debug_info") = false,
|
||||
py::arg("print_generic_op_form") = false,
|
||||
py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
|
||||
py::arg("use_local_scope") = false,
|
||||
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
|
||||
.def(
|
||||
"verify",
|
||||
[](PyOperationBase &self) {
|
||||
|
||||
@@ -394,11 +394,13 @@ public:
|
||||
/// Implements the bound 'print' method and helps with others.
|
||||
void print(pybind11::object fileObject, bool binary,
|
||||
llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
|
||||
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
|
||||
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified);
|
||||
pybind11::object getAsm(bool binary,
|
||||
llvm::Optional<int64_t> largeElementsLimit,
|
||||
bool enableDebugInfo, bool prettyDebugInfo,
|
||||
bool printGenericOpForm, bool useLocalScope);
|
||||
bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified);
|
||||
|
||||
/// Moves the operation before or after the other operation.
|
||||
void moveAfter(PyOperationBase &other);
|
||||
|
||||
@@ -175,7 +175,8 @@ def testBuildFuncOp():
|
||||
# CHECK-LABEL: TEST: testFuncArgumentAccess
|
||||
@run
|
||||
def testFuncArgumentAccess():
|
||||
with Context(), Location.unknown():
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
@@ -185,38 +186,38 @@ def testFuncArgumentAccess():
|
||||
std.ReturnOp(func.arguments)
|
||||
func.arg_attrs = ArrayAttr.get([
|
||||
DictAttr.get({
|
||||
"foo": StringAttr.get("bar"),
|
||||
"baz": UnitAttr.get()
|
||||
"custom_dialect.foo": StringAttr.get("bar"),
|
||||
"custom_dialect.baz": UnitAttr.get()
|
||||
}),
|
||||
DictAttr.get({"qux": ArrayAttr.get([])})
|
||||
DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
|
||||
])
|
||||
func.result_attrs = ArrayAttr.get([
|
||||
DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
|
||||
DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
|
||||
DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
|
||||
DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
|
||||
])
|
||||
|
||||
other = builtin.FuncOp("other_func", ([f32, f32], []))
|
||||
with InsertionPoint(other.add_entry_block()):
|
||||
std.ReturnOp([])
|
||||
other.arg_attrs = [
|
||||
DictAttr.get({"foo": StringAttr.get("qux")}),
|
||||
DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
|
||||
DictAttr.get()
|
||||
]
|
||||
|
||||
# CHECK: [{baz, foo = "bar"}, {qux = []}]
|
||||
# CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
|
||||
print(func.arg_attrs)
|
||||
|
||||
# CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
|
||||
# CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
|
||||
print(func.result_attrs)
|
||||
|
||||
# CHECK: func @some_func(
|
||||
# CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
|
||||
# CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
|
||||
# CHECK: f32 {res1 = 4.200000e+01 : f32},
|
||||
# CHECK: f32 {res2 = 2.560000e+02 : f64})
|
||||
# CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
|
||||
# CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
|
||||
# CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
|
||||
# CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
|
||||
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
|
||||
#
|
||||
# CHECK: func @other_func(
|
||||
# CHECK: %{{.*}}: f32 {foo = "qux"},
|
||||
# CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
|
||||
# CHECK: %{{.*}}: f32)
|
||||
print(module)
|
||||
|
||||
@@ -405,4 +405,7 @@ with Context() as ctx, Location.unknown():
|
||||
return non_default_op_name(input, outs=[init_result])
|
||||
|
||||
|
||||
print(module)
|
||||
# TODO: Fix me! Conv and pooling ops above do not verify, which was uncovered
|
||||
# when switching to more robust module verification. For now, reverting to the
|
||||
# old behavior which does not verify on module print.
|
||||
print(module.operation.get_asm(assume_verified=True))
|
||||
|
||||
@@ -83,49 +83,6 @@ def testFill():
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStructuredOpOnTensors
|
||||
@run
|
||||
def testStructuredOpOnTensors():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
tensor_type = RankedTensorType.get((2, 3, 4), f32)
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp(
|
||||
name="matmul_test",
|
||||
type=FunctionType.get(
|
||||
inputs=[tensor_type, tensor_type], results=[tensor_type]))
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
lhs, rhs = func.entry_block.arguments
|
||||
result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
|
||||
std.ReturnOp([result])
|
||||
|
||||
# CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testStructuredOpOnBuffers
|
||||
@run
|
||||
def testStructuredOpOnBuffers():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
memref_type = MemRefType.get((2, 3, 4), f32)
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp(
|
||||
name="matmul_test",
|
||||
type=FunctionType.get(
|
||||
inputs=[memref_type, memref_type, memref_type], results=[]))
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
lhs, rhs, result = func.entry_block.arguments
|
||||
# TODO: prperly hook up the region.
|
||||
linalg.MatmulOp([lhs, rhs], outputs=[result])
|
||||
std.ReturnOp([])
|
||||
|
||||
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
|
||||
@run
|
||||
def testNamedStructuredOpCustomForm():
|
||||
|
||||
@@ -22,7 +22,8 @@ def testConstShape():
|
||||
@builtin.FuncOp.from_py_func(
|
||||
RankedTensorType.get((12, -1), f32))
|
||||
def const_shape_tensor(arg):
|
||||
return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
|
||||
return shape.ConstShapeOp(
|
||||
DenseElementsAttr.get(np.array([10, 20]), type=IndexType.get()))
|
||||
|
||||
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
|
||||
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
|
||||
|
||||
@@ -78,8 +78,11 @@ def testConstantIndexOp():
|
||||
@constructAndPrintInModule
|
||||
def testFunctionCalls():
|
||||
foo = builtin.FuncOp("foo", ([], []))
|
||||
foo.sym_visibility = StringAttr.get("private")
|
||||
bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
|
||||
bar.sym_visibility = StringAttr.get("private")
|
||||
qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
|
||||
qux.sym_visibility = StringAttr.get("private")
|
||||
|
||||
with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
|
||||
std.CallOp(foo, [])
|
||||
@@ -88,9 +91,9 @@ def testFunctionCalls():
|
||||
std.ReturnOp([])
|
||||
|
||||
|
||||
# CHECK: func @foo()
|
||||
# CHECK: func @bar() -> index
|
||||
# CHECK: func @qux() -> f32
|
||||
# CHECK: func private @foo()
|
||||
# CHECK: func private @bar() -> index
|
||||
# CHECK: func private @qux() -> f32
|
||||
# CHECK: func @caller() {
|
||||
# CHECK: call @foo() : () -> ()
|
||||
# CHECK: %0 = call @bar() : () -> index
|
||||
|
||||
@@ -8,11 +8,13 @@ def run(f):
|
||||
f()
|
||||
gc.collect()
|
||||
assert Context._get_live_count() == 0
|
||||
return f
|
||||
|
||||
|
||||
# Verify successful parse.
|
||||
# CHECK-LABEL: TEST: testParseSuccess
|
||||
# CHECK: module @successfulParse
|
||||
@run
|
||||
def testParseSuccess():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""module @successfulParse {}""", ctx)
|
||||
@@ -23,12 +25,11 @@ def testParseSuccess():
|
||||
module.dump() # Just outputs to stderr. Verifies that it functions.
|
||||
print(str(module))
|
||||
|
||||
run(testParseSuccess)
|
||||
|
||||
|
||||
# Verify parse error.
|
||||
# CHECK-LABEL: TEST: testParseError
|
||||
# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
|
||||
@run
|
||||
def testParseError():
|
||||
ctx = Context()
|
||||
try:
|
||||
@@ -38,12 +39,11 @@ def testParseError():
|
||||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
run(testParseError)
|
||||
|
||||
|
||||
# Verify successful parse.
|
||||
# CHECK-LABEL: TEST: testCreateEmpty
|
||||
# CHECK: module {
|
||||
@run
|
||||
def testCreateEmpty():
|
||||
ctx = Context()
|
||||
loc = Location.unknown(ctx)
|
||||
@@ -53,8 +53,6 @@ def testCreateEmpty():
|
||||
gc.collect()
|
||||
print(str(module))
|
||||
|
||||
run(testCreateEmpty)
|
||||
|
||||
|
||||
# Verify round-trip of ASM that contains unicode.
|
||||
# Note that this does not test that the print path converts unicode properly
|
||||
@@ -62,6 +60,7 @@ run(testCreateEmpty)
|
||||
# CHECK-LABEL: TEST: testRoundtripUnicode
|
||||
# CHECK: func private @roundtripUnicode()
|
||||
# CHECK: foo = "\F0\9F\98\8A"
|
||||
@run
|
||||
def testRoundtripUnicode():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""
|
||||
@@ -69,11 +68,28 @@ def testRoundtripUnicode():
|
||||
""", ctx)
|
||||
print(str(module))
|
||||
|
||||
run(testRoundtripUnicode)
|
||||
|
||||
# Verify round-trip of ASM that contains unicode.
|
||||
# Note that this does not test that the print path converts unicode properly
|
||||
# because MLIR asm always normalizes it to the hex encoding.
|
||||
# CHECK-LABEL: TEST: testRoundtripBinary
|
||||
# CHECK: func private @roundtripUnicode()
|
||||
# CHECK: foo = "\F0\9F\98\8A"
|
||||
@run
|
||||
def testRoundtripBinary():
|
||||
with Context():
|
||||
module = Module.parse(r"""
|
||||
func private @roundtripUnicode() attributes { foo = "😊" }
|
||||
""")
|
||||
binary_asm = module.operation.get_asm(binary=True)
|
||||
assert isinstance(binary_asm, bytes)
|
||||
module = Module.parse(binary_asm)
|
||||
print(module)
|
||||
|
||||
|
||||
# Tests that module.operation works and correctly interns instances.
|
||||
# CHECK-LABEL: TEST: testModuleOperation
|
||||
@run
|
||||
def testModuleOperation():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""module @successfulParse {}""", ctx)
|
||||
@@ -101,10 +117,9 @@ def testModuleOperation():
|
||||
assert ctx._get_live_operation_count() == 0
|
||||
assert ctx._get_live_module_count() == 0
|
||||
|
||||
run(testModuleOperation)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testModuleCapsule
|
||||
@run
|
||||
def testModuleCapsule():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""module @successfulParse {}""", ctx)
|
||||
@@ -122,5 +137,3 @@ def testModuleCapsule():
|
||||
gc.collect()
|
||||
assert ctx._get_live_module_count() == 0
|
||||
|
||||
|
||||
run(testModuleCapsule)
|
||||
|
||||
@@ -630,21 +630,50 @@ def testSingleResultProperty():
|
||||
print(module.body.operations[2])
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testPrintInvalidOperation
|
||||
def create_invalid_operation():
|
||||
# This module has two region and is invalid verify that we fallback
|
||||
# to the generic printer for safety.
|
||||
op = Operation.create("builtin.module", regions=2)
|
||||
op.regions[0].blocks.append()
|
||||
return op
|
||||
|
||||
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
|
||||
@run
|
||||
def testPrintInvalidOperation():
|
||||
def testInvalidOperationStrSoftFails():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
module = Operation.create("builtin.module", regions=2)
|
||||
# This module has two region and is invalid verify that we fallback
|
||||
# to the generic printer for safety.
|
||||
block = module.regions[0].blocks.append()
|
||||
invalid_op = create_invalid_operation()
|
||||
# Verify that we fallback to the generic printer for safety.
|
||||
# CHECK: // Verification failed, printing generic form
|
||||
# CHECK: "builtin.module"() ( {
|
||||
# CHECK: }) : () -> ()
|
||||
print(module)
|
||||
print(invalid_op)
|
||||
# CHECK: .verify = False
|
||||
print(f".verify = {module.operation.verify()}")
|
||||
print(f".verify = {invalid_op.operation.verify()}")
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
|
||||
@run
|
||||
def testInvalidModuleStrSoftFails():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
invalid_op = create_invalid_operation()
|
||||
# Verify that we fallback to the generic printer for safety.
|
||||
# CHECK: // Verification failed, printing generic form
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
|
||||
@run
|
||||
def testInvalidOperationGetAsmBinarySoftFails():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
invalid_op = create_invalid_operation()
|
||||
# Verify that we fallback to the generic printer for safety.
|
||||
# CHECK: b'// Verification failed, printing generic form\n
|
||||
print(invalid_op.get_asm(binary=True))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
|
||||
|
||||
Reference in New Issue
Block a user