[mlir][python] Add bindings for diagnostic handler.

I considered multiple approaches for this but settled on this one because I could make the lifetime management work in a reasonably easy way (others had issues with not being able to cast to a Python reference from a C++ constructor). We could stand to have more formatting helpers, but best to get the core mechanism in first.

Differential Revision: https://reviews.llvm.org/D116568
This commit is contained in:
Stella Laurenzo
2022-01-03 16:39:58 -08:00
parent 78f5014fea
commit 7ee25bc56f
4 changed files with 443 additions and 3 deletions

View File

@@ -511,6 +511,57 @@ void PyMlirContext::contextExit(const pybind11::object &excType,
PyThreadContextEntry::popContext(*this);
}
py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
// Note that ownership is transferred to the delete callback below by way of
// an explicit inc_ref (borrow).
PyDiagnosticHandler *pyHandler =
new PyDiagnosticHandler(get(), std::move(callback));
py::object pyHandlerObject =
py::cast(pyHandler, py::return_value_policy::take_ownership);
pyHandlerObject.inc_ref();
// In these C callbacks, the userData is a PyDiagnosticHandler* that is
// guaranteed to be known to pybind.
auto handlerCallback =
+[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
py::object pyDiagnosticObject =
py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
bool result = false;
{
// Since this can be called from arbitrary C++ contexts, always get the
// gil.
py::gil_scoped_acquire gil;
try {
result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
} catch (std::exception &e) {
fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
e.what());
pyHandler->hadError = true;
}
}
pyDiagnostic->invalidate();
return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
};
auto deleteCallback = +[](void *userData) {
auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
assert(pyHandler->registeredID && "handler is not registered");
pyHandler->registeredID.reset();
// Decrement reference, balancing the inc_ref() above.
py::object pyHandlerObject =
py::cast(pyHandler, py::return_value_policy::reference);
pyHandlerObject.dec_ref();
};
pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
return pyHandlerObject;
}
PyMlirContext &DefaultingPyMlirContext::resolve() {
PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
if (!context) {
@@ -656,6 +707,78 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
stack.pop_back();
}
//------------------------------------------------------------------------------
// PyDiagnostic*
//------------------------------------------------------------------------------
void PyDiagnostic::invalidate() {
valid = false;
if (materializedNotes) {
for (auto &noteObject : *materializedNotes) {
PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
note->invalidate();
}
}
}
PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
py::object callback)
: context(context), callback(std::move(callback)) {}
PyDiagnosticHandler::~PyDiagnosticHandler() {}
void PyDiagnosticHandler::detach() {
if (!registeredID)
return;
MlirDiagnosticHandlerID localID = *registeredID;
mlirContextDetachDiagnosticHandler(context, localID);
assert(!registeredID && "should have unregistered");
// Not strictly necessary but keeps stale pointers from being around to cause
// issues.
context = {nullptr};
}
void PyDiagnostic::checkValid() {
if (!valid) {
throw std::invalid_argument(
"Diagnostic is invalid (used outside of callback)");
}
}
MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
checkValid();
return mlirDiagnosticGetSeverity(diagnostic);
}
PyLocation PyDiagnostic::getLocation() {
checkValid();
MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
MlirContext context = mlirLocationGetContext(loc);
return PyLocation(PyMlirContext::forContext(context), loc);
}
py::str PyDiagnostic::getMessage() {
checkValid();
py::object fileObject = py::module::import("io").attr("StringIO")();
PyFileAccumulator accum(fileObject, /*binary=*/false);
mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
return fileObject.attr("getvalue")();
}
py::tuple PyDiagnostic::getNotes() {
checkValid();
if (materializedNotes)
return *materializedNotes;
intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
materializedNotes = py::tuple(numNotes);
for (intptr_t i = 0; i < numNotes; ++i) {
MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
}
return *materializedNotes;
}
//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects
//------------------------------------------------------------------------------
@@ -2024,6 +2147,36 @@ private:
//------------------------------------------------------------------------------
void mlir::python::populateIRCore(py::module &m) {
//----------------------------------------------------------------------------
// Enums.
//----------------------------------------------------------------------------
py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
.value("ERROR", MlirDiagnosticError)
.value("WARNING", MlirDiagnosticWarning)
.value("NOTE", MlirDiagnosticNote)
.value("REMARK", MlirDiagnosticRemark);
//----------------------------------------------------------------------------
// Mapping of Diagnostics.
//----------------------------------------------------------------------------
py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
.def_property_readonly("severity", &PyDiagnostic::getSeverity)
.def_property_readonly("location", &PyDiagnostic::getLocation)
.def_property_readonly("message", &PyDiagnostic::getMessage)
.def_property_readonly("notes", &PyDiagnostic::getNotes)
.def("__str__", [](PyDiagnostic &self) -> py::str {
if (!self.isValid())
return "<Invalid Diagnostic>";
return self.getMessage();
});
py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
.def("detach", &PyDiagnosticHandler::detach)
.def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
.def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
.def("__enter__", &PyDiagnosticHandler::contextEnter)
.def("__exit__", &PyDiagnosticHandler::contextExit);
//----------------------------------------------------------------------------
// Mapping of MlirContext.
//----------------------------------------------------------------------------
@@ -2079,6 +2232,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
.def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
py::arg("callback"),
"Attaches a diagnostic handler that will receive callbacks")
.def(
"enable_multithreading",
[](PyMlirContext &self, bool enable) {
@@ -2204,7 +2360,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
[](const std::vector<PyLocation> &pyLocations,
llvm::Optional<PyAttribute> metadata,
DefaultingPyMlirContext context) {
if (pyLocations.empty())
throw py::value_error("No locations provided");
@@ -2236,6 +2393,12 @@ void mlir::python::populateIRCore(py::module &m) {
"context",
[](PyLocation &self) { return self.getContext().getObject(); },
"Context that owns the Location")
.def(
"emit_error",
[](PyLocation &self, std::string message) {
mlirEmitError(self, message.c_str());
},
py::arg("message"), "Emits an error at this location")
.def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
mlirLocationPrint(self, printAccum.getCallback(),

View File

@@ -15,6 +15,7 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "llvm/ADT/DenseMap.h"
@@ -24,6 +25,8 @@ namespace mlir {
namespace python {
class PyBlock;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
class DefaultingPyLocation;
@@ -207,6 +210,10 @@ public:
const pybind11::object &excVal,
const pybind11::object &excTb);
/// Attaches a Python callback as a diagnostic handler, returning a
/// registration object (internally a PyDiagnosticHandler).
pybind11::object attachDiagnosticHandler(pybind11::object callback);
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -267,6 +274,75 @@ private:
PyMlirContextRef contextRef;
};
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
/// nested diagnostics (in the notes) as well.
class PyDiagnostic {
public:
PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
void invalidate();
bool isValid() { return valid; }
MlirDiagnosticSeverity getSeverity();
PyLocation getLocation();
pybind11::str getMessage();
pybind11::tuple getNotes();
private:
MlirDiagnostic diagnostic;
void checkValid();
/// If notes have been materialized from the diagnostic, then this will
/// be populated with the corresponding objects (all castable to
/// PyDiagnostic).
llvm::Optional<pybind11::tuple> materializedNotes;
bool valid = true;
};
/// Represents a diagnostic handler attached to the context. The handler's
/// callback will be invoked with PyDiagnostic instances until the detach()
/// method is called or the context is destroyed. A diagnostic handler can be
/// the subject of a `with` block, which will detach it when the block exits.
///
/// Since diagnostic handlers can call back into Python code which can do
/// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
/// etc), this is generally not deemed to be a great user-level API. Users
/// should generally use some form of DiagnosticCollector. If the handler raises
/// any exceptions, they will just be emitted to stderr and dropped.
///
/// The unique usage of this class means that its lifetime management is
/// different from most other parts of the API. Instances are always created
/// in an attached state and can transition to a detached state by either:
/// a) The context being destroyed and unregistering all handlers.
/// b) An explicit call to detach().
/// The object may remain live from a Python perspective for an arbitrary time
/// after detachment, but there is nothing the user can do with it (since there
/// is no way to attach an existing handler object).
class PyDiagnosticHandler {
public:
PyDiagnosticHandler(MlirContext context, pybind11::object callback);
~PyDiagnosticHandler();
bool isAttached() { return registeredID.hasValue(); }
bool getHadError() { return hadError; }
/// Detaches the handler. Does nothing if not attached.
void detach();
pybind11::object contextEnter() { return pybind11::cast(this); }
void contextExit(pybind11::object excType, pybind11::object excVal,
pybind11::object excTb) {
detach();
}
private:
MlirContext context;
pybind11::object callback;
llvm::Optional<MlirDiagnosticHandlerID> registeredID;
bool hadError = false;
friend class PyMlirContext;
};
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
/// order to differentiate it from the `Dialect` base class which is extended by
/// plugins which extend dialect functionality through extension python code.

View File

@@ -7,7 +7,7 @@
# * Local edits to signatures and types that MyPy did not auto detect (or
# detected incorrectly).
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence
from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
from typing import overload
@@ -43,6 +43,9 @@ __all__ = [
"Dialect",
"DialectDescriptor",
"Dialects",
"Diagnostic",
"DiagnosticHandler",
"DiagnosticSeverity",
"DictAttr",
"F16Type",
"F32Type",
@@ -425,8 +428,9 @@ class Context:
def _get_live_count() -> int: ...
def _get_live_module_count(self) -> int: ...
def _get_live_operation_count(self) -> int: ...
def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ...
def enable_multithreading(self, enable: bool) -> None: ...
def get_dialect_descriptor(name: dialect_name: str) -> "DialectDescriptor": ...
def get_dialect_descriptor(dialect_name: str) -> "DialectDescriptor": ...
def is_registered_operation(self, operation_name: str) -> bool: ...
def __enter__(self) -> "Context": ...
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -479,6 +483,31 @@ class Dialects:
def __getattr__(self, arg0: str) -> "Dialect": ...
def __getitem__(self, arg0: str) -> "Dialect": ...
class Diagnostic:
@property
def severity(self) -> "DiagnosticSeverity": ...
@property
def location(self) -> "Location": ...
@property
def message(self) -> str: ...
@property
def notes(self) -> Tuple["Diagnostic"]: ...
class DiagnosticHandler:
def detach(self) -> None: ...
@property
def attached(self) -> bool: ...
@property
def had_error(self) -> bool: ...
def __enter__(self) -> "DiagnosticHandler": ...
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
class DiagnosticSeverity:
ERROR: "DiagnosticSeverity"
WARNING: "DiagnosticSeverity"
NOTE: "DiagnosticSeverity"
REMARK: "DiagnosticSeverity"
# TODO: Auto-generated. Audit and fix.
class DictAttr(Attribute):
def __init__(self, cast_from_attr: Attribute) -> None: ...

View File

@@ -0,0 +1,172 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0
return f
@run
def testLifecycleContextDestroy():
ctx = Context()
def callback(foo): ...
handler = ctx.attach_diagnostic_handler(callback)
assert handler.attached
# If context is destroyed before the handler, it should auto-detach.
ctx = None
gc.collect()
assert not handler.attached
# And finally collecting the handler should be fine.
handler = None
gc.collect()
@run
def testLifecycleExplicitDetach():
ctx = Context()
def callback(foo): ...
handler = ctx.attach_diagnostic_handler(callback)
assert handler.attached
handler.detach()
assert not handler.attached
@run
def testLifecycleWith():
ctx = Context()
def callback(foo): ...
with ctx.attach_diagnostic_handler(callback) as handler:
assert handler.attached
assert not handler.attached
@run
def testLifecycleWithAndExplicitDetach():
ctx = Context()
def callback(foo): ...
with ctx.attach_diagnostic_handler(callback) as handler:
assert handler.attached
handler.detach()
assert not handler.attached
# CHECK-LABEL: TEST: testDiagnosticCallback
@run
def testDiagnosticCallback():
ctx = Context()
def callback(d):
# CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
loc.emit_error("foobar")
assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
# TODO: Come up with a way to inject a diagnostic with notes from this API.
@run
def testDiagnosticEmptyNotes():
ctx = Context()
def callback(d):
# CHECK: DIAGNOSTIC: notes=()
print(f"DIAGNOSTIC: notes={d.notes}")
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
loc.emit_error("foobar")
assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticCallbackException
@run
def testDiagnosticCallbackException():
ctx = Context()
def callback(d):
raise ValueError("Error in handler")
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
loc.emit_error("foobar")
assert handler.had_error
# CHECK-LABEL: TEST: testEscapingDiagnostic
@run
def testEscapingDiagnostic():
ctx = Context()
diags = []
def callback(d):
diags.append(d)
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
loc.emit_error("foobar")
assert not handler.had_error
# CHECK: DIAGNOSTIC: <Invalid Diagnostic>
print(f"DIAGNOSTIC: {str(diags[0])}")
try:
diags[0].severity
raise RuntimeError("expected exception")
except ValueError:
pass
try:
diags[0].location
raise RuntimeError("expected exception")
except ValueError:
pass
try:
diags[0].message
raise RuntimeError("expected exception")
except ValueError:
pass
try:
diags[0].notes
raise RuntimeError("expected exception")
except ValueError:
pass
# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
@run
def testDiagnosticReturnTrueHandles():
ctx = Context()
def callback1(d):
print(f"CALLBACK1: {d}")
return True
def callback2(d):
print(f"CALLBACK2: {d}")
return True
ctx.attach_diagnostic_handler(callback1)
ctx.attach_diagnostic_handler(callback2)
loc = Location.unknown(ctx)
# CHECK-NOT: CALLBACK1
# CHECK: CALLBACK2: foobar
# CHECK-NOT: CALLBACK1
loc.emit_error("foobar")
# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
@run
def testDiagnosticReturnFalseDoesNotHandle():
ctx = Context()
def callback1(d):
print(f"CALLBACK1: {d}")
return True
def callback2(d):
print(f"CALLBACK2: {d}")
return False
ctx.attach_diagnostic_handler(callback1)
ctx.attach_diagnostic_handler(callback2)
loc = Location.unknown(ctx)
# CHECK: CALLBACK2: foobar
# CHECK: CALLBACK1: foobar
loc.emit_error("foobar")