mirror of
https://github.com/intel/llvm.git
synced 2026-01-22 23:49:22 +08:00
Add python bindings for Type and IntegerType.
* The binding for Type is trivial and should be non-controversial. * The way that I define the IntegerType should serve as a pattern for what I want to do next. * I propose defining the rest of the standard types in this fashion and then generalizing for dialect types as necessary. * Essentially, creating/accessing a concrete Type (vs interacting with the string form) is done by "casting" to the concrete type (i.e. IntegerType can be constructed with a Type and will throw if the cast is illegal). * This deviates from some of our previous discussions about global objects but I think produces a usable API and we should go this way. Differential Revision: https://reviews.llvm.org/D86179
This commit is contained in:
@@ -9,7 +9,10 @@
|
||||
#include "IRModules.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -20,6 +23,15 @@ static const char kContextParseDocstring[] =
|
||||
R"(Parses a module's assembly format from a string.
|
||||
|
||||
Returns a new MlirModule or raises a ValueError if the parsing fails.
|
||||
|
||||
See also: https://mlir.llvm.org/docs/LangRef/
|
||||
)";
|
||||
|
||||
static const char kContextParseType[] = R"(Parses the assembly form of a type.
|
||||
|
||||
Returns a Type object or raises a ValueError if the type cannot be parsed.
|
||||
|
||||
See also: https://mlir.llvm.org/docs/LangRef/#type-system
|
||||
)";
|
||||
|
||||
static const char kOperationStrDunderDocstring[] =
|
||||
@@ -30,6 +42,9 @@ use the dedicated print method, which supports keyword arguments to customize
|
||||
behavior.
|
||||
)";
|
||||
|
||||
static const char kTypeStrDunderDocstring[] =
|
||||
R"(Prints the assembly form of the type.)";
|
||||
|
||||
static const char kDumpDocstring[] =
|
||||
R"(Dumps a debug representation of the object to stderr.)";
|
||||
|
||||
@@ -64,39 +79,154 @@ struct PyPrintAccumulator {
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Context Wrapper Class.
|
||||
// PyType.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
PyMlirModule PyMlirContext::parse(const std::string &module) {
|
||||
auto moduleRef = mlirModuleCreateParse(context, module.c_str());
|
||||
if (!moduleRef.ptr) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
"Unable to parse module assembly (see diagnostics)");
|
||||
}
|
||||
return PyMlirModule(moduleRef);
|
||||
bool PyType::operator==(const PyType &other) {
|
||||
return mlirTypeEqual(type, other.type);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Module Wrapper Class.
|
||||
// Standard type subclasses.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
|
||||
namespace {
|
||||
|
||||
/// CRTP base classes for Python types that subclass Type and should be
|
||||
/// castable from it (i.e. via something like IntegerType(t)).
|
||||
template <typename T>
|
||||
class PyConcreteType : public PyType {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = py::class_<T, PyType>;
|
||||
using IsAFunctionTy = int (*)(MlirType);
|
||||
|
||||
PyConcreteType() = default;
|
||||
PyConcreteType(MlirType t) : PyType(t) {}
|
||||
PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
|
||||
|
||||
static MlirType castFrom(PyType &orig) {
|
||||
if (!T::isaFunction(orig.type)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
|
||||
T::pyClassName + " (from " +
|
||||
origRepr + ")");
|
||||
}
|
||||
return orig.type;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto class_ = ClassTy(m, T::pyClassName);
|
||||
class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
||||
T::bindDerived(class_);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
class PyIntegerType : public PyConcreteType<PyIntegerType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
|
||||
static constexpr const char *pyClassName = "IntegerType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"signless",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signless integer type");
|
||||
c.def_static(
|
||||
"signed",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create a signed integer type");
|
||||
c.def_static(
|
||||
"unsigned",
|
||||
[](PyMlirContext &context, unsigned width) {
|
||||
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
|
||||
return PyIntegerType(t);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Create an unsigned integer type");
|
||||
c.def_property_readonly(
|
||||
"width",
|
||||
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
|
||||
"Returns the width of the integer type");
|
||||
c.def_property_readonly(
|
||||
"is_signless",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsSignless(self.type);
|
||||
},
|
||||
"Returns whether this is a signless integer");
|
||||
c.def_property_readonly(
|
||||
"is_signed",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsSigned(self.type);
|
||||
},
|
||||
"Returns whether this is a signed integer");
|
||||
c.def_property_readonly(
|
||||
"is_unsigned",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsUnsigned(self.type);
|
||||
},
|
||||
"Returns whether this is an unsigned integer");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Populates the pybind11 IR submodule.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
py::class_<PyMlirContext>(m, "MlirContext")
|
||||
// Mapping of MlirContext
|
||||
py::class_<PyMlirContext>(m, "Context")
|
||||
.def(py::init<>())
|
||||
.def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
|
||||
kContextParseDocstring);
|
||||
.def(
|
||||
"parse_module",
|
||||
[](PyMlirContext &self, const std::string module) {
|
||||
auto moduleRef =
|
||||
mlirModuleCreateParse(self.context, module.c_str());
|
||||
if (mlirModuleIsNull(moduleRef)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"Unable to parse module assembly (see diagnostics)");
|
||||
}
|
||||
return PyModule(moduleRef);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseDocstring)
|
||||
.def(
|
||||
"parse_type",
|
||||
[](PyMlirContext &self, std::string typeSpec) {
|
||||
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
|
||||
if (mlirTypeIsNull(type)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Unable to parse type: '") +
|
||||
typeSpec + "'");
|
||||
}
|
||||
return PyType(type);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseType);
|
||||
|
||||
py::class_<PyMlirModule>(m, "MlirModule")
|
||||
.def("dump", &PyMlirModule::dump, kDumpDocstring)
|
||||
// Mapping of Module
|
||||
py::class_<PyModule>(m, "Module")
|
||||
.def(
|
||||
"dump",
|
||||
[](PyModule &self) {
|
||||
mlirOperationDump(mlirModuleGetOperation(self.module));
|
||||
},
|
||||
kDumpDocstring)
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyMlirModule &self) {
|
||||
[](PyModule &self) {
|
||||
auto operation = mlirModuleGetOperation(self.module);
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirOperationPrint(operation, printAccum.getCallback(),
|
||||
@@ -104,4 +234,42 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
return printAccum.join();
|
||||
},
|
||||
kOperationStrDunderDocstring);
|
||||
|
||||
// Mapping of Type.
|
||||
py::class_<PyType>(m, "Type")
|
||||
.def("__eq__",
|
||||
[](PyType &self, py::object &other) {
|
||||
try {
|
||||
PyType otherType = other.cast<PyType>();
|
||||
return self == otherType;
|
||||
} catch (std::exception &e) {
|
||||
return false;
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyType &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirTypePrint(self.type, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
},
|
||||
kTypeStrDunderDocstring)
|
||||
.def("__repr__", [](PyType &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, types are an exception as they typically have compact
|
||||
// assembly forms and printing them is useful.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Type(");
|
||||
mlirTypePrint(self.type, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
// Standard type bindings.
|
||||
PyIntegerType::bind(m);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user