[MLIR][python bindings] Expose TypeIDs in python

This diff adds python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from python APIs (see D150927) and then further along building type "conscious" `Value` APIs (see D150413).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150839
This commit is contained in:
max
2023-05-22 11:12:53 -05:00
parent a7c5cf2260
commit d39a784402
11 changed files with 435 additions and 17 deletions

View File

@@ -80,6 +80,8 @@
#define MLIR_PYTHON_CAPSULE_PASS_MANAGER \
MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr")
#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
#define MLIR_PYTHON_CAPSULE_TYPEID \
MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
/** Attribute on MLIR Python objects that expose their C-API pointer.
* This will be a type-specific capsule created as per one of the helpers
@@ -268,6 +270,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) {
return op;
}
/** Creates a capsule object encapsulating the raw C-API MlirTypeID.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the type in any way.
*/
static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID),
MLIR_PYTHON_CAPSULE_TYPEID, NULL);
}
/** Extracts an MlirTypeID from a capsule as produced from
* mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then
* a null type is returned (as checked via mlirTypeIDIsNull). In such a
* case, the Python APIs will have already set an error. */
static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID);
MlirTypeID typeID = {ptr};
return typeID;
}
/** Creates a capsule object encapsulating the raw C-API MlirType.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the type in any way.

View File

@@ -22,6 +22,9 @@ extern "C" {
// Integer types.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Integer type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void);
/// Checks whether the given type is an integer type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
@@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type);
// Index type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Index type.
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void);
/// Checks whether the given type is an index type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
@@ -67,6 +73,9 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
// Floating-point types.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
/// Checks whether the given type is an f8E5M2 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
@@ -74,6 +83,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
/// Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
/// Checks whether the given type is an f8E4M3FN type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
@@ -81,6 +93,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
/// Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
/// Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
@@ -88,6 +103,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
/// Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
/// Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
@@ -95,6 +113,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
/// Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
/// Checks whether the given type is an f8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
@@ -102,6 +123,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
/// Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
@@ -109,6 +133,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
/// Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
/// Checks whether the given type is an f16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
@@ -116,6 +143,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
/// Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
/// Checks whether the given type is an f32 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
@@ -123,6 +153,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
/// Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
/// Checks whether the given type is an f64 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
@@ -134,6 +167,9 @@ MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
// None type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an None type.
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void);
/// Checks whether the given type is a None type.
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
@@ -145,6 +181,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx);
// Complex type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Complex type.
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void);
/// Checks whether the given type is a Complex type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
@@ -159,6 +198,9 @@ MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type);
// Shaped type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Shaped type.
MLIR_CAPI_EXPORTED MlirTypeID mlirShapedTypeGetTypeID(void);
/// Checks whether the given type is a Shaped type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type);
@@ -202,6 +244,9 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
// Vector type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Vector type.
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void);
/// Checks whether the given type is a Vector type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
@@ -226,9 +271,15 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
/// Checks whether the given type is a Tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
/// Returns the typeID of an RankedTensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void);
/// Checks whether the given type is a ranked tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
/// Returns the typeID of an UnrankedTensor type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void);
/// Checks whether the given type is an unranked tensor type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
@@ -264,9 +315,15 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an MemRef type.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void);
/// Checks whether the given type is a MemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
/// Returns the typeID of an UnrankedMemRef type.
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void);
/// Checks whether the given type is an UnrankedMemRef type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
@@ -326,6 +383,9 @@ mlirUnrankedMemrefGetMemorySpace(MlirType type);
// Tuple type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Tuple type.
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
/// Checks whether the given type is a tuple type.
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
@@ -345,6 +405,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
// Function type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Function type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
/// Checks whether the given type is a function type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
@@ -373,6 +436,9 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
// Opaque type.
//===----------------------------------------------------------------------===//
/// Returns the typeID of an Opaque type.
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void);
/// Checks whether the given type is an opaque type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);

View File

@@ -236,6 +236,27 @@ struct type_caster<MlirPassManager> {
}
};
/// Casts object <-> MlirTypeID.
template <>
struct type_caster<MlirTypeID> {
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
bool load(handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToTypeID(capsule.ptr());
return !mlirTypeIDIsNull(value);
}
static handle cast(MlirTypeID v, return_value_policy, handle) {
if (v.ptr == nullptr)
return py::none();
py::object capsule =
py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("TypeID")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
};
};
/// Casts object <-> MlirType.
template <>
struct type_caster<MlirType> {

View File

@@ -17,6 +17,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -1807,6 +1808,24 @@ PyType PyType::createFromCapsule(py::object capsule) {
rawType);
}
//------------------------------------------------------------------------------
// PyTypeID.
//------------------------------------------------------------------------------
py::object PyTypeID::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
}
PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
if (mlirTypeIDIsNull(mlirTypeID))
throw py::error_already_set();
return PyTypeID(mlirTypeID);
}
bool PyTypeID::operator==(const PyTypeID &other) const {
return mlirTypeIDEqual(typeID, other.typeID);
}
//------------------------------------------------------------------------------
// PyValue and subclases.
//------------------------------------------------------------------------------
@@ -3268,16 +3287,47 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
"Returns the assembly form of the type.")
.def("__repr__", [](PyType &self) {
// Generally, assembly formats are not printed for __repr__ because
// this can cause exceptionally long debug output and exceptions.
// However, types are an exception as they typically have compact
// assembly forms and printing them is useful.
PyPrintAccumulator printAccum;
printAccum.parts.append("Type(");
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
.def("__repr__",
[](PyType &self) {
// Generally, assembly formats are not printed for __repr__ because
// this can cause exceptionally long debug output and exceptions.
// However, types are an exception as they typically have compact
// assembly forms and printing them is useful.
PyPrintAccumulator printAccum;
printAccum.parts.append("Type(");
mlirTypePrint(self, printAccum.getCallback(),
printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
})
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
return mlirTypeID;
auto origRepr =
pybind11::repr(pybind11::cast(self)).cast<std::string>();
throw py::value_error(
(origRepr + llvm::Twine(" has no typeid.")).str());
});
//----------------------------------------------------------------------------
// Mapping of PyTypeID.
//----------------------------------------------------------------------------
py::class_<PyTypeID>(m, "TypeID", py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
// Note, this tests whether the underlying TypeIDs are the same,
// not whether the wrapper MlirTypeIDs are the same, nor whether
// the Python objects are the same (i.e., PyTypeID is a value type).
.def("__eq__",
[](PyTypeID &self, PyTypeID &other) { return self == other; })
.def("__eq__",
[](PyTypeID &self, const py::object &other) { return false; })
// Note, this gives the hash value of the underlying TypeID, not the
// hash value of the Python object, nor the hash value of the
// MlirTypeID wrapper.
.def("__hash__", [](PyTypeID &self) {
return static_cast<size_t>(mlirTypeIDHashValue(self));
});
//----------------------------------------------------------------------------

View File

@@ -20,6 +20,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@@ -826,6 +827,29 @@ private:
MlirType type;
};
/// A TypeID provides an efficient and unique identifier for a specific C++
/// type. This allows for a C++ type to be compared, hashed, and stored in an
/// opaque context. This class wraps around the generic MlirTypeID.
class PyTypeID {
public:
PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
// Note, this tests whether the underlying TypeIDs are the same,
// not whether the wrapper MlirTypeIDs are the same, nor whether
// the PyTypeID objects are the same (i.e., PyTypeID is a value type).
bool operator==(const PyTypeID &other) const;
operator MlirTypeID() const { return typeID; }
MlirTypeID get() { return typeID; }
/// Gets a capsule wrapping the void* within the MlirTypeID.
pybind11::object getCapsule();
/// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
static PyTypeID createFromCapsule(pybind11::object capsule);
private:
MlirTypeID typeID;
};
/// CRTP base classes for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
/// By default, type class hierarchies are one level deep (i.e. a
@@ -839,10 +863,14 @@ public:
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
: BaseTy(std::move(contextRef), t) {}
: BaseTy(std::move(contextRef), t) {
pybind11::implicitly_convertible<PyType, DerivedTy>();
}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
@@ -866,6 +894,26 @@ public:
return DerivedTy::isaFunction(otherType);
},
pybind11::arg("other"));
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw SetPyError(PyExc_AttributeError,
DerivedTy::pyClassName +
llvm::Twine(" has no typeid."));
});
cls.def_property_readonly("typeid", [](PyType &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
printAccum.parts.append(DerivedTy::pyClassName);
printAccum.parts.append("(");
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
printAccum.parts.append(")");
return printAccum.join();
});
DerivedTy::bindDerived(cls);
}

View File

@@ -32,6 +32,8 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
class PyIntegerType : public PyConcreteType<PyIntegerType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirIntegerTypeGetTypeID;
static constexpr const char *pyClassName = "IntegerType";
using PyConcreteType::PyConcreteType;
@@ -89,6 +91,8 @@ public:
class PyIndexType : public PyConcreteType<PyIndexType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirIndexTypeGetTypeID;
static constexpr const char *pyClassName = "IndexType";
using PyConcreteType::PyConcreteType;
@@ -107,6 +111,8 @@ public:
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNType";
using PyConcreteType::PyConcreteType;
@@ -125,6 +131,8 @@ public:
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E5M2TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2Type";
using PyConcreteType::PyConcreteType;
@@ -143,6 +151,8 @@ public:
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
using PyConcreteType::PyConcreteType;
@@ -161,6 +171,8 @@ public:
class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3B11FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
using PyConcreteType::PyConcreteType;
@@ -179,6 +191,8 @@ public:
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E5M2FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
using PyConcreteType::PyConcreteType;
@@ -197,6 +211,8 @@ public:
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirBFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "BF16Type";
using PyConcreteType::PyConcreteType;
@@ -215,6 +231,8 @@ public:
class PyF16Type : public PyConcreteType<PyF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "F16Type";
using PyConcreteType::PyConcreteType;
@@ -233,6 +251,8 @@ public:
class PyF32Type : public PyConcreteType<PyF32Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat32TypeGetTypeID;
static constexpr const char *pyClassName = "F32Type";
using PyConcreteType::PyConcreteType;
@@ -251,6 +271,8 @@ public:
class PyF64Type : public PyConcreteType<PyF64Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat64TypeGetTypeID;
static constexpr const char *pyClassName = "F64Type";
using PyConcreteType::PyConcreteType;
@@ -269,6 +291,8 @@ public:
class PyNoneType : public PyConcreteType<PyNoneType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirNoneTypeGetTypeID;
static constexpr const char *pyClassName = "NoneType";
using PyConcreteType::PyConcreteType;
@@ -287,6 +311,8 @@ public:
class PyComplexType : public PyConcreteType<PyComplexType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirComplexTypeGetTypeID;
static constexpr const char *pyClassName = "ComplexType";
using PyConcreteType::PyConcreteType;
@@ -417,6 +443,8 @@ private:
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirVectorTypeGetTypeID;
static constexpr const char *pyClassName = "VectorType";
using PyConcreteType::PyConcreteType;
@@ -442,6 +470,8 @@ class PyRankedTensorType
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirRankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "RankedTensorType";
using PyConcreteType::PyConcreteType;
@@ -476,6 +506,8 @@ class PyUnrankedTensorType
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUnrankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedTensorType";
using PyConcreteType::PyConcreteType;
@@ -498,6 +530,8 @@ public:
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;
@@ -550,6 +584,8 @@ class PyUnrankedMemRefType
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUnrankedMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedMemRefType";
using PyConcreteType::PyConcreteType;
@@ -585,6 +621,8 @@ public:
class PyTupleType : public PyConcreteType<PyTupleType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTupleTypeGetTypeID;
static constexpr const char *pyClassName = "TupleType";
using PyConcreteType::PyConcreteType;
@@ -622,6 +660,8 @@ public:
class PyFunctionType : public PyConcreteType<PyFunctionType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFunctionTypeGetTypeID;
static constexpr const char *pyClassName = "FunctionType";
using PyConcreteType::PyConcreteType;
@@ -676,6 +716,8 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirOpaqueTypeGetTypeID;
static constexpr const char *pyClassName = "OpaqueType";
using PyConcreteType::PyConcreteType;

View File

@@ -22,6 +22,8 @@ using namespace mlir;
// Integer types.
//===----------------------------------------------------------------------===//
MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
bool mlirTypeIsAInteger(MlirType type) {
return llvm::isa<IntegerType>(unwrap(type));
}
@@ -58,6 +60,8 @@ bool mlirIntegerTypeIsUnsigned(MlirType type) {
// Index type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
bool mlirTypeIsAIndex(MlirType type) {
return llvm::isa<IndexType>(unwrap(type));
}
@@ -70,6 +74,10 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
// Floating-point types.
//===----------------------------------------------------------------------===//
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
return wrap(Float8E5M2Type::getTypeID());
}
bool mlirTypeIsAFloat8E5M2(MlirType type) {
return unwrap(type).isFloat8E5M2();
}
@@ -78,6 +86,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
}
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
return wrap(Float8E4M3FNType::getTypeID());
}
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
return unwrap(type).isFloat8E4M3FN();
}
@@ -86,6 +98,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
}
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
return wrap(Float8E5M2FNUZType::getTypeID());
}
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
return unwrap(type).isFloat8E5M2FNUZ();
}
@@ -94,6 +110,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
}
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
return wrap(Float8E4M3FNUZType::getTypeID());
}
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3FNUZ();
}
@@ -102,6 +122,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
}
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
return wrap(Float8E4M3B11FNUZType::getTypeID());
}
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3B11FNUZ();
}
@@ -110,24 +134,34 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
}
MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getBF16(unwrap(ctx)));
}
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(FloatType::getF16(unwrap(ctx)));
}
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(FloatType::getF32(unwrap(ctx)));
}
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
MlirType mlirF64TypeGet(MlirContext ctx) {
@@ -138,6 +172,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
// None type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
bool mlirTypeIsANone(MlirType type) {
return llvm::isa<NoneType>(unwrap(type));
}
@@ -150,6 +186,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
// Complex type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
bool mlirTypeIsAComplex(MlirType type) {
return llvm::isa<ComplexType>(unwrap(type));
}
@@ -214,6 +252,8 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
// Vector type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
bool mlirTypeIsAVector(MlirType type) {
return llvm::isa<VectorType>(unwrap(type));
}
@@ -239,10 +279,18 @@ bool mlirTypeIsATensor(MlirType type) {
return llvm::isa<TensorType>(unwrap(type));
}
MlirTypeID mlirRankedTensorTypeGetTypeID() {
return wrap(RankedTensorType::getTypeID());
}
bool mlirTypeIsARankedTensor(MlirType type) {
return llvm::isa<RankedTensorType>(unwrap(type));
}
MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
return wrap(UnrankedTensorType::getTypeID());
}
bool mlirTypeIsAUnrankedTensor(MlirType type) {
return llvm::isa<UnrankedTensorType>(unwrap(type));
}
@@ -280,6 +328,8 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
bool mlirTypeIsAMemRef(MlirType type) {
return llvm::isa<MemRefType>(unwrap(type));
}
@@ -337,6 +387,10 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
return llvm::isa<UnrankedMemRefType>(unwrap(type));
}
@@ -362,6 +416,8 @@ MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
// Tuple type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
bool mlirTypeIsATuple(MlirType type) {
return llvm::isa<TupleType>(unwrap(type));
}
@@ -386,6 +442,10 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
// Function type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirFunctionTypeGetTypeID() {
return wrap(FunctionType::getTypeID());
}
bool mlirTypeIsAFunction(MlirType type) {
return llvm::isa<FunctionType>(unwrap(type));
}
@@ -424,6 +484,8 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
// Opaque type.
//===----------------------------------------------------------------------===//
MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
bool mlirTypeIsAOpaque(MlirType type) {
return llvm::isa<OpaqueType>(unwrap(type));
}

View File

@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest

View File

@@ -299,6 +299,14 @@ def testCustomType():
# The following cast must not assert.
b = test.TestType(a)
# Instance custom types should have typeids
assert isinstance(b.typeid, TypeID)
# Subclasses of ir.Type should not have a static_typeid
# CHECK: 'TestType' object has no attribute 'static_typeid'
try:
b.static_typeid
except AttributeError as e:
print(e)
i8 = IntegerType.get_signless(8)
try:
@@ -353,6 +361,12 @@ def testTensorValue():
# CHECK: False
print(tt.is_null())
# Classes of custom types that inherit from concrete types should have
# static_typeid
assert isinstance(test.TestTensorType.static_typeid, TypeID)
# And it should be equal to the in-tree concrete type
assert test.TestTensorType.static_typeid == t.type.typeid
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run

View File

@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -76,6 +77,7 @@ def testTypeHash():
# CHECK: len(s): 2
print("len(s): ", len(s))
# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
@@ -182,6 +184,7 @@ def testIntegerType():
# CHECK: unsigned: ui64
print("unsigned:", IntegerType.get_unsigned(64))
# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
@@ -259,7 +262,8 @@ def testConcreteShapedType():
# CHECK: rank: 2
print("rank:", vector.rank)
# CHECK: whether the shaped type has a static shape: True
print("whether the shaped type has a static shape:", vector.has_static_shape)
print("whether the shaped type has a static shape:",
vector.has_static_shape)
# CHECK: whether the dim-th dimension is dynamic: False
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
# CHECK: dim size: 3
@@ -311,8 +315,7 @@ def testRankedTensorType():
shape = [2, 3]
loc = Location.unknown()
# CHECK: ranked tensor type: tensor<2x3xf32>
print("ranked tensor type:",
RankedTensorType.get(shape, f32))
print("ranked tensor type:", RankedTensorType.get(shape, f32))
none = NoneType.get()
try:
@@ -477,8 +480,7 @@ def testTupleType():
@run
def testFunctionType():
with Context() as ctx:
input_types = [IntegerType.get_signless(32),
IntegerType.get_signless(16)]
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
result_types = [IndexType.get()]
func = FunctionType.get(input_types, result_types)
# CHECK: INPUTS: [Type(i32), Type(i16)]
@@ -509,3 +511,91 @@ def testShapedTypeConstants():
print(type(ShapedType.get_dynamic_size()))
# CHECK: <class 'int'>
print(type(ShapedType.get_dynamic_stride_or_offset()))
# CHECK-LABEL: TEST: testTypeIDs
@run
def testTypeIDs():
with Context(), Location.unknown():
f32 = F32Type.get()
types = [
(IntegerType, IntegerType.get_signless(16)),
(IndexType, IndexType.get()),
(Float8E4M3FNType, Float8E4M3FNType.get()),
(Float8E5M2Type, Float8E5M2Type.get()),
(Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
(Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
(Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
(BF16Type, BF16Type.get()),
(F16Type, F16Type.get()),
(F32Type, F32Type.get()),
(F64Type, F64Type.get()),
(NoneType, NoneType.get()),
(ComplexType, ComplexType.get(f32)),
(VectorType, VectorType.get([2, 3], f32)),
(RankedTensorType, RankedTensorType.get([2, 3], f32)),
(UnrankedTensorType, UnrankedTensorType.get(f32)),
(MemRefType, MemRefType.get([2, 3], f32)),
(UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
(TupleType, TupleType.get_tuple([f32])),
(FunctionType, FunctionType.get([], [])),
(OpaqueType, OpaqueType.get("tensor", "bob")),
]
# CHECK: IntegerType(i16)
# CHECK: IndexType(index)
# CHECK: Float8E4M3FNType(f8E4M3FN)
# CHECK: Float8E5M2Type(f8E5M2)
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
# CHECK: BF16Type(bf16)
# CHECK: F16Type(f16)
# CHECK: F32Type(f32)
# CHECK: F64Type(f64)
# CHECK: NoneType(none)
# CHECK: ComplexType(complex<f32>)
# CHECK: VectorType(vector<2x3xf32>)
# CHECK: RankedTensorType(tensor<2x3xf32>)
# CHECK: UnrankedTensorType(tensor<*xf32>)
# CHECK: MemRefType(memref<2x3xf32>)
# CHECK: UnrankedMemRefType(memref<*xf32, 2>)
# CHECK: TupleType(tuple<f32>)
# CHECK: FunctionType(() -> ())
# CHECK: OpaqueType(!tensor.bob)
for _, t in types:
print(repr(t))
# Test getTypeIdFunction agrees with
# mlirTypeGetTypeID(self) for an instance.
# CHECK: all equal
for t1, t2 in types:
tid1, tid2 = t1.static_typeid, Type(t2).typeid
assert tid1 == tid2 and hash(tid1) == hash(
tid2), f"expected hash and value equality {t1} {t2}"
else:
print("all equal")
# Test that storing PyTypeID in python dicts
# works as expected.
typeid_dict = dict(types)
assert len(typeid_dict)
# CHECK: all equal
for t1, t2 in typeid_dict.items():
assert t1.static_typeid == t2.typeid and hash(
t1.static_typeid) == hash(
t2.typeid), f"expected hash and value equality {t1} {t2}"
else:
print("all equal")
# CHECK: ShapedType has no typeid.
try:
print(ShapedType.static_typeid)
except AttributeError as e:
print(e)
vector_type = Type.parse("vector<2x3xf32>")
# CHECK: True
print(ShapedType(vector_type).typeid == vector_type.typeid)

View File

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
@@ -40,6 +41,9 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("RankedTensorType"));
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });