Files
llvm/mlir/lib/Bindings/Python/PybindUtils.h

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

278 lines
9.7 KiB
C
Raw Normal View History

//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#include "mlir-c/Support.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace mlir {
namespace python {
/// Sets a python error, ready to be thrown to return control back to the
/// python runtime.
///
/// The function itself does not throw: it hands back a
/// `pybind11::error_already_set` value that the caller is expected to throw.
///
/// \param excClass the Python exception class to raise (for example
///        `PyExc_ValueError`).
/// \param message  the text attached to the raised error.
///
/// Correct usage:
///   throw SetPyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set SetPyError(PyObject *excClass,
const llvm::Twine &message);
[mlir][Python] Context managers for Context, InsertionPoint, Location. * Finishes support for Context, InsertionPoint and Location to be carried by the thread using context managers. * Introduces type casters and utilities so that DefaultPyMlirContext and DefaultPyLocation in method signatures does the right thing (allows explicit or gets from the thread context). * Extend the rules for the thread context stack to handle nesting, appropriately inheriting and clearing depending on whether the context is the same. * Refactors all method signatures to follow the new convention on trailing parameters for defaulting parameters (loc, ip, context). When the objects are carried in the thread context, this allows most explicit uses of these values to be elided. * Removes the style guide section on putting accessors to construct global objects on the PyMlirContext: this style fails to make good use of the new facility since it is often the only thing remaining needing an MlirContext. * Moves Module parse/creation from mlir.ir.Context to static methods on mlir.ir.Module. * Moves Context.create_operation to a static Operation.create method. * Moves Type parsing from mlir.ir.Context to static methods on mlir.ir.Type. * Moves Attribute parsing from mlir.ir.Context to static methods on mlir.ir.Attribute. * Move Location factory methods from mlir.ir.Context to static methods on mlir.ir.Location. * Refactors the std dialect fake "ODS" generated code to take advantage of the new scheme. Differential Revision: https://reviews.llvm.org/D90547
2020-10-31 23:40:25 -07:00
/// CRTP base for wrapper types that may be supplied as 'None' in a function
/// argument list and are then filled in through some global mechanism. An
/// instance that was never resolved is illegal to use; failed resolution is
/// expected to surface as an error. To client code the wrapper looks like a
/// thin smart pointer to the wrapped type: it exposes `get()` and an
/// overloaded `operator->`.
///
/// A derived type has to supply the environmental resolution hook, which must
/// raise an exception when nothing can be resolved:
///   static ReferrentTy &resolve()
///
/// and a parameter description used in error messages about mismatched
/// types:
///   static constexpr const char kTypeDescription[] = "<Description>";
template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;

  /// Default construction exists only because type casters demand it; an
  /// instance created this way holds a null pointer and must not be used.
  Defaulting() = default;

  /// Wraps an already available object. Only its address is stored.
  Defaulting(ReferrentTy &referrent) : resolved(&referrent) {}

  /// Returns the wrapped pointer (null for a default-constructed instance).
  ReferrentTy *get() { return resolved; }

  /// Member access on the wrapped object, same pointer as `get()`.
  ReferrentTy *operator->() { return get(); }

private:
  /// Non-owning pointer to the wrapped object.
  ReferrentTy *resolved = nullptr;
};
} // namespace python
} // namespace mlir
[mlir][Python] Context managers for Context, InsertionPoint, Location. * Finishes support for Context, InsertionPoint and Location to be carried by the thread using context managers. * Introduces type casters and utilities so that DefaultPyMlirContext and DefaultPyLocation in method signatures does the right thing (allows explicit or gets from the thread context). * Extend the rules for the thread context stack to handle nesting, appropriately inheriting and clearing depending on whether the context is the same. * Refactors all method signatures to follow the new convention on trailing parameters for defaulting parameters (loc, ip, context). When the objects are carried in the thread context, this allows most explicit uses of these values to be elided. * Removes the style guide section on putting accessors to construct global objects on the PyMlirContext: this style fails to make good use of the new facility since it is often the only thing remaining needing an MlirContext. * Moves Module parse/creation from mlir.ir.Context to static methods on mlir.ir.Module. * Moves Context.create_operation to a static Operation.create method. * Moves Type parsing from mlir.ir.Context to static methods on mlir.ir.Type. * Moves Attribute parsing from mlir.ir.Context to static methods on mlir.ir.Attribute. * Move Location factory methods from mlir.ir.Context to static methods on mlir.ir.Location. * Refactors the std dialect fake "ODS" generated code to take advantage of the new scheme. Differential Revision: https://reviews.llvm.org/D90547
2020-10-31 23:40:25 -07:00
namespace pybind11 {
namespace detail {
/// pybind11 caster shared by `Defaulting`-style wrapper types.
///
/// `DefaultingTy` must provide `kTypeDescription` (used as the caster's type
/// name), a static `resolve()` returning the object to wrap when the Python
/// argument is `None`, and a `ReferrentTy` member type.
template <typename DefaultingTy>
struct MlirDefaultingCaster {
  PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));

  /// Converts a Python argument into a `DefaultingTy`.
  ///
  /// `None` is replaced by `DefaultingTy::resolve()`; any other object is cast
  /// to `ReferrentTy &`. Returns false if that cast throws.
  bool load(pybind11::handle src, bool) {
    if (src.is_none()) {
      // Note that we do want an exception to propagate from here as it will be
      // the most informative.
      value = DefaultingTy{DefaultingTy::resolve()};
      return true;
    }

    // Unlike many casters that chain, these casters are expected to always
    // succeed, so instead of doing an isinstance check followed by a cast,
    // just cast in one step and handle the exception. Returning false (vs
    // letting the exception propagate) causes higher level signature parsing
    // code to produce nice error messages (other than "Cannot cast...").
    try {
      value = DefaultingTy{
          pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
      return true;
    } catch (const std::exception &) {
      return false;
    }
  }

  /// Converts a `DefaultingTy` back into a Python object.
  // NOTE(review): `parent` is not forwarded to pybind11::cast here; confirm
  // whether this return-value path is ever exercised by the bindings.
  static handle cast(DefaultingTy src, return_value_policy policy,
                     handle parent) {
    return pybind11::cast(src, policy);
  }
};
/// Makes `llvm::Optional<T>` usable in bound signatures by deriving its
/// `type_caster` from `optional_caster`; no behavior is added here.
// NOTE(review): presumably this maps Python `None` to an empty Optional, as
// pybind11 does for std::optional -- confirm against pybind11/stl.h.
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
} // namespace detail
} // namespace pybind11
//------------------------------------------------------------------------------
// Conversion utilities.
//------------------------------------------------------------------------------
namespace mlir {
/// Collects the chunks emitted through an MlirStringCallback and stitches
/// them together into a single python string.
struct PyPrintAccumulator {
  /// Chunks received so far, in arrival order.
  pybind11::list parts;

  /// Opaque pointer to hand to the C API next to the callback.
  void *getUserData() { return this; }

  /// Returns a capture-less callback that appends every chunk it receives to
  /// `parts` of the accumulator passed as user data.
  MlirStringCallback getCallback() {
    return [](const char *part, intptr_t size, void *userData) {
      auto *self = static_cast<PyPrintAccumulator *>(userData);
      // Decodes as UTF-8 by default.
      self->parts.append(pybind11::str(part, size));
    };
  }

  /// Joins all collected chunks using an empty separator.
  pybind11::str join() {
    pybind11::str separator("", 0);
    return separator.attr("join")(parts);
  }
};
/// Accumulates into a python file-like object, writing either text or binary
/// as selected by the `binary` constructor argument.
class PyFileAccumulator {
public:
  /// Looks up `fileObject.write` once and keeps it for all later chunks.
  ///
  /// \param fileObject python object exposing a `write` attribute.
  /// \param binary     when true chunks are written as `bytes`, otherwise as
  ///                   `str`.
  PyFileAccumulator(pybind11::object fileObject, bool binary)
      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}

  /// Opaque pointer to hand to the C API next to the callback.
  void *getUserData() { return this; }

  /// Returns a capture-less callback that forwards every chunk to the stored
  /// python `write` function.
  MlirStringCallback getCallback() {
    return [](const char *part, intptr_t size, void *userData) {
      // The guard has to be a named local. An unnamed temporary
      // (`pybind11::gil_scoped_acquire();`) is destroyed at the end of its own
      // statement, which would leave the python calls below running without
      // the GIL held by this guard.
      pybind11::gil_scoped_acquire acquire;
      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
      if (accum->binary) {
        // Note: Still has to copy and not avoidable with this API.
        pybind11::bytes pyBytes(part, size);
        accum->pyWriteFunction(pyBytes);
      } else {
        pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
        accum->pyWriteFunction(pyStr);
      }
    };
  }

private:
  /// Bound `write` attribute of the target file object.
  pybind11::object pyWriteFunction;
  /// Selects `bytes` (true) or `str` (false) for the written chunks.
  bool binary;
};
/// String accumulator for producers that are expected to invoke the callback
/// exactly once -- neither zero times nor repeatedly. Violations are caught by
/// internal assertions.
struct PySinglePartStringAccumulator {
  /// Opaque pointer to hand to the C API next to the callback.
  void *getUserData() { return this; }

  /// Returns a capture-less callback that stores the single chunk it receives
  /// and records that it has been called.
  MlirStringCallback getCallback() {
    return [](const char *part, intptr_t size, void *userData) {
      auto *self = static_cast<PySinglePartStringAccumulator *>(userData);
      assert(!self->invoked &&
             "PySinglePartStringAccumulator called back multiple times");
      self->invoked = true;
      self->value = pybind11::str(part, size);
    };
  }

  /// Moves the stored string out; asserts that the callback has run.
  pybind11::str takeValue() {
    assert(invoked && "PySinglePartStringAccumulator not called back");
    pybind11::str taken = std::move(value);
    return taken;
  }

private:
  /// The chunk delivered by the callback.
  pybind11::str value;
  /// Set once the callback has been invoked.
  bool invoked = false;
};
/// A CRTP base class for pseudo-containers willing to support Python-type
/// slicing access on top of indexed access. Calling ::bind on this class
/// will define `__len__` as well as `__getitem__` with integer and slice
/// arguments.
///
/// This is intended for pseudo-containers that can refer to arbitrary slices of
/// underlying storage indexed by a single integer. Indexing those with an
/// integer produces an instance of ElementTy. Indexing those with a slice
/// produces a new instance of Derived, which can be sliced further.
///
/// A derived class must provide the following:
///   - a `static const char *pyClassName` field containing the name of the
///     Python class to bind;
///   - an instance method `intptr_t getNumElements()` that returns the number
///     of elements in the backing container (NOT that of the slice);
///   - an instance method `ElementTy getElement(intptr_t)` that returns a
///     single element at the given index.
///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
///     constructs a new instance of the derived pseudo-container with the
///     given slice parameters (to be forwarded to the Sliceable constructor).
///
/// A derived class may additionally define:
///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
///     on the python class.
template <typename Derived, typename ElementTy>
class Sliceable {
protected:
  using ClassTy = pybind11::class_<Derived>;

public:
  /// Creates a view of `length` elements starting at `startIndex` of the
  /// backing container and advancing by `step` per element.
  explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
      : startIndex(startIndex), length(length), step(step) {
    assert(length >= 0 && "expected non-negative slice length");
  }

  /// Returns the length of the slice.
  intptr_t dunderLen() const { return length; }

  /// Returns the element at the given slice index. Supports negative indices
  /// by taking elements in inverse order. Throws if the index is out of bounds.
  ElementTy dunderGetItem(intptr_t index) {
    // Negative indices mean we count from the end.
    if (index < 0)
      index = length + index;
    if (index < 0 || index >= length) {
      throw python::SetPyError(PyExc_IndexError,
                               "attempt to access out of bounds");
    }

    // Compute the linear index given the current slice properties. Keep the
    // full intptr_t width: all operands are intptr_t and narrowing the result
    // to `int` could truncate it where intptr_t is wider than int.
    intptr_t linearIndex = index * step + startIndex;
    assert(linearIndex >= 0 &&
           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
           "linear index out of bounds, the slice is ill-formed");
    return static_cast<Derived *>(this)->getElement(linearIndex);
  }

  /// Returns a new instance of the pseudo-container restricted to the given
  /// slice. Throws an IndexError if the slice cannot be resolved against the
  /// current length.
  Derived dunderGetItemSlice(pybind11::slice slice) {
    ssize_t start, stop, extraStep, sliceLength;
    if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
      throw python::SetPyError(PyExc_IndexError,
                               "attempt to access out of bounds");
    }
    // Compose the requested slice with the one this object already represents.
    return static_cast<Derived *>(this)->slice(startIndex + start * step,
                                               sliceLength, step * extraStep);
  }

  /// Binds the indexing and length methods in the Python class.
  static void bind(pybind11::module &m) {
    auto clazz = ClassTy(m, Derived::pyClassName)
                     .def("__len__", &Sliceable::dunderLen)
                     .def("__getitem__", &Sliceable::dunderGetItem)
                     .def("__getitem__", &Sliceable::dunderGetItemSlice);
    Derived::bindDerived(clazz);
  }

  /// Hook for derived classes willing to bind more methods.
  static void bindDerived(ClassTy &) {}

private:
  /// Index in the backing container of the first element of the slice.
  intptr_t startIndex;
  /// Number of elements in the slice.
  intptr_t length;
  /// Distance in the backing container between consecutive slice elements.
  intptr_t step;
};
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H