[mlir] Better Python diagnostics (#128581)

Updated the Python diagnostics handler to emit notes (in addition to
errors) into the output stream so that users have more context as to
where in the IR the error is occurring.
This commit is contained in:
Nikhil Kalra
2025-03-10 15:59:47 -07:00
committed by GitHub
parent 9b066f0b57
commit b15ccd436a
5 changed files with 63 additions and 13 deletions

View File

@@ -9,12 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#include <cassert>
#include <string>
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
#include <string>
namespace mlir {
namespace python {
@@ -24,33 +25,45 @@ namespace python {
class CollectDiagnosticsToStringScope {
public:
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
/*deleteUserData=*/nullptr);
handlerID =
mlirContextAttachDiagnosticHandler(ctx, &handler, &messageStream,
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
assert(errorMessage.empty() && "unchecked error message");
assert(message.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
[[nodiscard]] std::string takeMessage() {
std::string newMessage;
std::swap(message, newMessage);
return newMessage;
}
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
*static_cast<std::string *>(data) +=
llvm::StringRef(message.data, message.length);
*static_cast<llvm::raw_string_ostream *>(data)
<< std::string_view(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
*static_cast<std::string *>(data) += "at ";
*static_cast<llvm::raw_string_ostream *>(data) << "at ";
mlirLocationPrint(loc, printer, data);
*static_cast<std::string *>(data) += ": ";
*static_cast<llvm::raw_string_ostream *>(data) << ": ";
mlirDiagnosticPrint(diag, printer, data);
for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
*static_cast<llvm::raw_string_ostream *>(data) << "\n";
MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
handler(note, data);
}
return mlirLogicalResultSuccess();
}
MlirContext context;
MlirDiagnosticHandlerID handlerID;
std::string errorMessage = "";
std::string message;
llvm::raw_string_ostream messageStream{message};
};
} // namespace python

View File

@@ -2,6 +2,9 @@
import gc
from mlir.ir import *
from mlir._mlir_libs._mlirPythonTestNanobind import (
test_diagnostics_with_errors_and_notes,
)
def run(f):
@@ -222,3 +225,16 @@ def testDiagnosticReturnFalseDoesNotHandle():
# CHECK: CALLBACK2: foobar
# CHECK: CALLBACK1: foobar
loc.emit_error("foobar")
# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
@run
def testBuiltInDiagnosticsHandler():
ctx = Context()
try:
test_diagnostics_with_errors_and_notes(ctx)
except ValueError as e:
# CHECK: created error
# CHECK: attached note
print(e)

View File

@@ -11,6 +11,8 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
python_test::PythonTestDialect)
@@ -42,3 +44,9 @@ MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
}
void mlirPythonTestEmitDiagnosticWithNote(MlirContext ctx) {
auto diag =
mlir::emitError(unwrap(mlirLocationUnknownGet(ctx)), "created error");
diag.attachNote() << "attached note";
}

View File

@@ -10,6 +10,7 @@
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
@@ -33,6 +34,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
MLIR_CAPI_EXPORTED void mlirPythonTestEmitDiagnosticWithNote(MlirContext ctx);
#ifdef __cplusplus
}
#endif

View File

@@ -11,9 +11,12 @@
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,13 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
},
nb::arg("registry"));
m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
mlir::python::CollectDiagnosticsToStringScope handler(ctx);
mlirPythonTestEmitDiagnosticWithNote(ctx);
throw nb::value_error(handler.takeMessage().c_str());
});
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute,
mlirPythonTestTestAttributeGetTypeID)