mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[MLIR][python bindings] Expose TypeIDs in python
This diff adds python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from python APIs (see D150927) and then further along building type "conscious" `Value` APIs (see D150413). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150839
This commit is contained in:
@@ -80,6 +80,8 @@
|
||||
#define MLIR_PYTHON_CAPSULE_PASS_MANAGER \
|
||||
MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr")
|
||||
#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
|
||||
#define MLIR_PYTHON_CAPSULE_TYPEID \
|
||||
MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
|
||||
|
||||
/** Attribute on MLIR Python objects that expose their C-API pointer.
|
||||
* This will be a type-specific capsule created as per one of the helpers
|
||||
@@ -268,6 +270,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) {
|
||||
return op;
|
||||
}
|
||||
|
||||
/** Creates a capsule object encapsulating the raw C-API MlirTypeID.
|
||||
* The returned capsule does not extend or affect ownership of any Python
|
||||
* objects that reference the type in any way.
|
||||
*/
|
||||
static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) {
|
||||
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID),
|
||||
MLIR_PYTHON_CAPSULE_TYPEID, NULL);
|
||||
}
|
||||
|
||||
/** Extracts an MlirTypeID from a capsule as produced from
|
||||
* mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then
|
||||
* a null type is returned (as checked via mlirTypeIDIsNull). In such a
|
||||
* case, the Python APIs will have already set an error. */
|
||||
static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) {
|
||||
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID);
|
||||
MlirTypeID typeID = {ptr};
|
||||
return typeID;
|
||||
}
|
||||
|
||||
/** Creates a capsule object encapsulating the raw C-API MlirType.
|
||||
* The returned capsule does not extend or affect ownership of any Python
|
||||
* objects that reference the type in any way.
|
||||
|
||||
@@ -22,6 +22,9 @@ extern "C" {
|
||||
// Integer types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Integer type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an integer type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
|
||||
|
||||
@@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type);
|
||||
// Index type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Index type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an index type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
|
||||
|
||||
@@ -67,6 +73,9 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
|
||||
// Floating-point types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Float8E5M2 type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f8E5M2 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
|
||||
|
||||
@@ -74,6 +83,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float8E4M3FN type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f8E4M3FN type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
|
||||
|
||||
@@ -81,6 +93,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float8E5M2FNUZ type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f8E5M2FNUZ type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
|
||||
|
||||
@@ -88,6 +103,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float8E4M3FNUZ type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f8E4M3FNUZ type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
|
||||
|
||||
@@ -95,6 +113,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float8E4M3B11FNUZ type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f8E4M3B11FNUZ type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
|
||||
|
||||
@@ -102,6 +123,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an BFloat16 type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a bf16 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
|
||||
|
||||
@@ -109,6 +133,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float16 type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f16 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
|
||||
|
||||
@@ -116,6 +143,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float32 type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f32 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
|
||||
|
||||
@@ -123,6 +153,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
|
||||
/// context.
|
||||
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
|
||||
|
||||
/// Returns the typeID of an Float64 type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an f64 type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
|
||||
|
||||
@@ -134,6 +167,9 @@ MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
|
||||
// None type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an None type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a None type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
|
||||
|
||||
@@ -145,6 +181,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx);
|
||||
// Complex type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Complex type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a Complex type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
|
||||
|
||||
@@ -159,6 +198,9 @@ MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type);
|
||||
// Shaped type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Shaped type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirShapedTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a Shaped type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type);
|
||||
|
||||
@@ -202,6 +244,9 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
|
||||
// Vector type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Vector type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a Vector type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type);
|
||||
|
||||
@@ -226,9 +271,15 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
|
||||
/// Checks whether the given type is a Tensor type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type);
|
||||
|
||||
/// Returns the typeID of an RankedTensor type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a ranked tensor type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type);
|
||||
|
||||
/// Returns the typeID of an UnrankedTensor type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an unranked tensor type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);
|
||||
|
||||
@@ -264,9 +315,15 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
|
||||
// Ranked / Unranked MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an MemRef type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a MemRef type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type);
|
||||
|
||||
/// Returns the typeID of an UnrankedMemRef type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an UnrankedMemRef type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type);
|
||||
|
||||
@@ -326,6 +383,9 @@ mlirUnrankedMemrefGetMemorySpace(MlirType type);
|
||||
// Tuple type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Tuple type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a tuple type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
|
||||
|
||||
@@ -345,6 +405,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
|
||||
// Function type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Function type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is a function type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
|
||||
|
||||
@@ -373,6 +436,9 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
|
||||
// Opaque type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the typeID of an Opaque type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void);
|
||||
|
||||
/// Checks whether the given type is an opaque type.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
|
||||
|
||||
|
||||
@@ -236,6 +236,27 @@ struct type_caster<MlirPassManager> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object <-> MlirTypeID.
|
||||
template <>
|
||||
struct type_caster<MlirTypeID> {
|
||||
PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
|
||||
bool load(handle src, bool) {
|
||||
py::object capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToTypeID(capsule.ptr());
|
||||
return !mlirTypeIDIsNull(value);
|
||||
}
|
||||
static handle cast(MlirTypeID v, return_value_policy, handle) {
|
||||
if (v.ptr == nullptr)
|
||||
return py::none();
|
||||
py::object capsule =
|
||||
py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
|
||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("TypeID")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.release();
|
||||
};
|
||||
};
|
||||
|
||||
/// Casts object <-> MlirType.
|
||||
template <>
|
||||
struct type_caster<MlirType> {
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -1807,6 +1808,24 @@ PyType PyType::createFromCapsule(py::object capsule) {
|
||||
rawType);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyTypeID.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
py::object PyTypeID::getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
|
||||
}
|
||||
|
||||
PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
|
||||
MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
|
||||
if (mlirTypeIDIsNull(mlirTypeID))
|
||||
throw py::error_already_set();
|
||||
return PyTypeID(mlirTypeID);
|
||||
}
|
||||
bool PyTypeID::operator==(const PyTypeID &other) const {
|
||||
return mlirTypeIDEqual(typeID, other.typeID);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyValue and subclases.
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -3268,16 +3287,47 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
return printAccum.join();
|
||||
},
|
||||
"Returns the assembly form of the type.")
|
||||
.def("__repr__", [](PyType &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, types are an exception as they typically have compact
|
||||
// assembly forms and printing them is useful.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Type(");
|
||||
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
.def("__repr__",
|
||||
[](PyType &self) {
|
||||
// Generally, assembly formats are not printed for __repr__ because
|
||||
// this can cause exceptionally long debug output and exceptions.
|
||||
// However, types are an exception as they typically have compact
|
||||
// assembly forms and printing them is useful.
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append("Type(");
|
||||
mlirTypePrint(self, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
})
|
||||
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
|
||||
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
|
||||
if (!mlirTypeIDIsNull(mlirTypeID))
|
||||
return mlirTypeID;
|
||||
auto origRepr =
|
||||
pybind11::repr(pybind11::cast(self)).cast<std::string>();
|
||||
throw py::value_error(
|
||||
(origRepr + llvm::Twine(" has no typeid.")).str());
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyTypeID.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyTypeID>(m, "TypeID", py::module_local())
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
|
||||
// Note, this tests whether the underlying TypeIDs are the same,
|
||||
// not whether the wrapper MlirTypeIDs are the same, nor whether
|
||||
// the Python objects are the same (i.e., PyTypeID is a value type).
|
||||
.def("__eq__",
|
||||
[](PyTypeID &self, PyTypeID &other) { return self == other; })
|
||||
.def("__eq__",
|
||||
[](PyTypeID &self, const py::object &other) { return false; })
|
||||
// Note, this gives the hash value of the underlying TypeID, not the
|
||||
// hash value of the Python object, nor the hash value of the
|
||||
// MlirTypeID wrapper.
|
||||
.def("__hash__", [](PyTypeID &self) {
|
||||
return static_cast<size_t>(mlirTypeIDHashValue(self));
|
||||
});
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -826,6 +827,29 @@ private:
|
||||
MlirType type;
|
||||
};
|
||||
|
||||
/// A TypeID provides an efficient and unique identifier for a specific C++
|
||||
/// type. This allows for a C++ type to be compared, hashed, and stored in an
|
||||
/// opaque context. This class wraps around the generic MlirTypeID.
|
||||
class PyTypeID {
|
||||
public:
|
||||
PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
|
||||
// Note, this tests whether the underlying TypeIDs are the same,
|
||||
// not whether the wrapper MlirTypeIDs are the same, nor whether
|
||||
// the PyTypeID objects are the same (i.e., PyTypeID is a value type).
|
||||
bool operator==(const PyTypeID &other) const;
|
||||
operator MlirTypeID() const { return typeID; }
|
||||
MlirTypeID get() { return typeID; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirTypeID.
|
||||
pybind11::object getCapsule();
|
||||
|
||||
/// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
|
||||
static PyTypeID createFromCapsule(pybind11::object capsule);
|
||||
|
||||
private:
|
||||
MlirTypeID typeID;
|
||||
};
|
||||
|
||||
/// CRTP base classes for Python types that subclass Type and should be
|
||||
/// castable from it (i.e. via something like IntegerType(t)).
|
||||
/// By default, type class hierarchies are one level deep (i.e. a
|
||||
@@ -839,10 +863,14 @@ public:
|
||||
// const char *pyClassName
|
||||
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirType);
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
|
||||
|
||||
PyConcreteType() = default;
|
||||
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
|
||||
: BaseTy(std::move(contextRef), t) {}
|
||||
: BaseTy(std::move(contextRef), t) {
|
||||
pybind11::implicitly_convertible<PyType, DerivedTy>();
|
||||
}
|
||||
PyConcreteType(PyType &orig)
|
||||
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
@@ -866,6 +894,26 @@ public:
|
||||
return DerivedTy::isaFunction(otherType);
|
||||
},
|
||||
pybind11::arg("other"));
|
||||
cls.def_property_readonly_static(
|
||||
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
|
||||
if (DerivedTy::getTypeIdFunction)
|
||||
return DerivedTy::getTypeIdFunction();
|
||||
throw SetPyError(PyExc_AttributeError,
|
||||
DerivedTy::pyClassName +
|
||||
llvm::Twine(" has no typeid."));
|
||||
});
|
||||
cls.def_property_readonly("typeid", [](PyType &self) {
|
||||
return py::cast(self).attr("typeid").cast<MlirTypeID>();
|
||||
});
|
||||
cls.def("__repr__", [](DerivedTy &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
printAccum.parts.append(DerivedTy::pyClassName);
|
||||
printAccum.parts.append("(");
|
||||
mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
||||
class PyIntegerType : public PyConcreteType<PyIntegerType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirIntegerTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "IntegerType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -89,6 +91,8 @@ public:
|
||||
class PyIndexType : public PyConcreteType<PyIndexType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirIndexTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "IndexType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -107,6 +111,8 @@ public:
|
||||
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat8E4M3FNTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "Float8E4M3FNType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -125,6 +131,8 @@ public:
|
||||
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat8E5M2TypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "Float8E5M2Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -143,6 +151,8 @@ public:
|
||||
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat8E4M3FNUZTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -161,6 +171,8 @@ public:
|
||||
class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat8E4M3B11FNUZTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -179,6 +191,8 @@ public:
|
||||
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat8E5M2FNUZTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -197,6 +211,8 @@ public:
|
||||
class PyBF16Type : public PyConcreteType<PyBF16Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirBFloat16TypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "BF16Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -215,6 +231,8 @@ public:
|
||||
class PyF16Type : public PyConcreteType<PyF16Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat16TypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "F16Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -233,6 +251,8 @@ public:
|
||||
class PyF32Type : public PyConcreteType<PyF32Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat32TypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "F32Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -251,6 +271,8 @@ public:
|
||||
class PyF64Type : public PyConcreteType<PyF64Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFloat64TypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "F64Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -269,6 +291,8 @@ public:
|
||||
class PyNoneType : public PyConcreteType<PyNoneType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirNoneTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "NoneType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -287,6 +311,8 @@ public:
|
||||
class PyComplexType : public PyConcreteType<PyComplexType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirComplexTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "ComplexType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -417,6 +443,8 @@ private:
|
||||
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirVectorTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "VectorType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -442,6 +470,8 @@ class PyRankedTensorType
|
||||
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirRankedTensorTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "RankedTensorType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -476,6 +506,8 @@ class PyUnrankedTensorType
|
||||
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUnrankedTensorTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "UnrankedTensorType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -498,6 +530,8 @@ public:
|
||||
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirMemRefTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "MemRefType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -550,6 +584,8 @@ class PyUnrankedMemRefType
|
||||
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUnrankedMemRefTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "UnrankedMemRefType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -585,6 +621,8 @@ public:
|
||||
class PyTupleType : public PyConcreteType<PyTupleType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirTupleTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "TupleType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -622,6 +660,8 @@ public:
|
||||
class PyFunctionType : public PyConcreteType<PyFunctionType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirFunctionTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "FunctionType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
@@ -676,6 +716,8 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
|
||||
class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirOpaqueTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "OpaqueType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ using namespace mlir;
|
||||
// Integer types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAInteger(MlirType type) {
|
||||
return llvm::isa<IntegerType>(unwrap(type));
|
||||
}
|
||||
@@ -58,6 +60,8 @@ bool mlirIntegerTypeIsUnsigned(MlirType type) {
|
||||
// Index type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAIndex(MlirType type) {
|
||||
return llvm::isa<IndexType>(unwrap(type));
|
||||
}
|
||||
@@ -70,6 +74,10 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
|
||||
// Floating-point types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
|
||||
return wrap(Float8E5M2Type::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E5M2(MlirType type) {
|
||||
return unwrap(type).isFloat8E5M2();
|
||||
}
|
||||
@@ -78,6 +86,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
|
||||
return wrap(Float8E4M3FNType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
|
||||
return unwrap(type).isFloat8E4M3FN();
|
||||
}
|
||||
@@ -86,6 +98,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
|
||||
return wrap(Float8E5M2FNUZType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
|
||||
return unwrap(type).isFloat8E5M2FNUZ();
|
||||
}
|
||||
@@ -94,6 +110,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
|
||||
return wrap(Float8E4M3FNUZType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
|
||||
return unwrap(type).isFloat8E4M3FNUZ();
|
||||
}
|
||||
@@ -102,6 +122,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
|
||||
return wrap(Float8E4M3B11FNUZType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
|
||||
return unwrap(type).isFloat8E4M3B11FNUZ();
|
||||
}
|
||||
@@ -110,24 +134,34 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirBFloat16TypeGetTypeID() {
|
||||
return wrap(BFloat16Type::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
|
||||
|
||||
MlirType mlirBF16TypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getBF16(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
|
||||
|
||||
MlirType mlirF16TypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getF16(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
|
||||
|
||||
MlirType mlirF32TypeGet(MlirContext ctx) {
|
||||
return wrap(FloatType::getF32(unwrap(ctx)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
|
||||
|
||||
MlirType mlirF64TypeGet(MlirContext ctx) {
|
||||
@@ -138,6 +172,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
|
||||
// None type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsANone(MlirType type) {
|
||||
return llvm::isa<NoneType>(unwrap(type));
|
||||
}
|
||||
@@ -150,6 +186,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
|
||||
// Complex type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAComplex(MlirType type) {
|
||||
return llvm::isa<ComplexType>(unwrap(type));
|
||||
}
|
||||
@@ -214,6 +252,8 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
|
||||
// Vector type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAVector(MlirType type) {
|
||||
return llvm::isa<VectorType>(unwrap(type));
|
||||
}
|
||||
@@ -239,10 +279,18 @@ bool mlirTypeIsATensor(MlirType type) {
|
||||
return llvm::isa<TensorType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirRankedTensorTypeGetTypeID() {
|
||||
return wrap(RankedTensorType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsARankedTensor(MlirType type) {
|
||||
return llvm::isa<RankedTensorType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
|
||||
return wrap(UnrankedTensorType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAUnrankedTensor(MlirType type) {
|
||||
return llvm::isa<UnrankedTensorType>(unwrap(type));
|
||||
}
|
||||
@@ -280,6 +328,8 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
|
||||
// Ranked / Unranked MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAMemRef(MlirType type) {
|
||||
return llvm::isa<MemRefType>(unwrap(type));
|
||||
}
|
||||
@@ -337,6 +387,10 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
|
||||
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
|
||||
}
|
||||
|
||||
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
|
||||
return wrap(UnrankedMemRefType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
|
||||
return llvm::isa<UnrankedMemRefType>(unwrap(type));
|
||||
}
|
||||
@@ -362,6 +416,8 @@ MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
|
||||
// Tuple type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsATuple(MlirType type) {
|
||||
return llvm::isa<TupleType>(unwrap(type));
|
||||
}
|
||||
@@ -386,6 +442,10 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
|
||||
// Function type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirFunctionTypeGetTypeID() {
|
||||
return wrap(FunctionType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAFunction(MlirType type) {
|
||||
return llvm::isa<FunctionType>(unwrap(type));
|
||||
}
|
||||
@@ -424,6 +484,8 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
|
||||
// Opaque type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
|
||||
|
||||
bool mlirTypeIsAOpaque(MlirType type) {
|
||||
return llvm::isa<OpaqueType>(unwrap(type));
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._python_test_ops_gen import *
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
|
||||
|
||||
def register_python_test_dialect(context, load=True):
|
||||
from .._mlir_libs import _mlirPythonTest
|
||||
|
||||
@@ -299,6 +299,14 @@ def testCustomType():
|
||||
|
||||
# The following cast must not assert.
|
||||
b = test.TestType(a)
|
||||
# Instance custom types should have typeids
|
||||
assert isinstance(b.typeid, TypeID)
|
||||
# Subclasses of ir.Type should not have a static_typeid
|
||||
# CHECK: 'TestType' object has no attribute 'static_typeid'
|
||||
try:
|
||||
b.static_typeid
|
||||
except AttributeError as e:
|
||||
print(e)
|
||||
|
||||
i8 = IntegerType.get_signless(8)
|
||||
try:
|
||||
@@ -353,6 +361,12 @@ def testTensorValue():
|
||||
# CHECK: False
|
||||
print(tt.is_null())
|
||||
|
||||
# Classes of custom types that inherit from concrete types should have
|
||||
# static_typeid
|
||||
assert isinstance(test.TestTensorType.static_typeid, TypeID)
|
||||
# And it should be equal to the in-tree concrete type
|
||||
assert test.TestTensorType.static_typeid == t.type.typeid
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: inferReturnTypeComponents
|
||||
@run
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import gc
|
||||
from mlir.ir import *
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
@@ -76,6 +77,7 @@ def testTypeHash():
|
||||
# CHECK: len(s): 2
|
||||
print("len(s): ", len(s))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeCast
|
||||
@run
|
||||
def testTypeCast():
|
||||
@@ -182,6 +184,7 @@ def testIntegerType():
|
||||
# CHECK: unsigned: ui64
|
||||
print("unsigned:", IntegerType.get_unsigned(64))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testIndexType
|
||||
@run
|
||||
def testIndexType():
|
||||
@@ -259,7 +262,8 @@ def testConcreteShapedType():
|
||||
# CHECK: rank: 2
|
||||
print("rank:", vector.rank)
|
||||
# CHECK: whether the shaped type has a static shape: True
|
||||
print("whether the shaped type has a static shape:", vector.has_static_shape)
|
||||
print("whether the shaped type has a static shape:",
|
||||
vector.has_static_shape)
|
||||
# CHECK: whether the dim-th dimension is dynamic: False
|
||||
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
|
||||
# CHECK: dim size: 3
|
||||
@@ -311,8 +315,7 @@ def testRankedTensorType():
|
||||
shape = [2, 3]
|
||||
loc = Location.unknown()
|
||||
# CHECK: ranked tensor type: tensor<2x3xf32>
|
||||
print("ranked tensor type:",
|
||||
RankedTensorType.get(shape, f32))
|
||||
print("ranked tensor type:", RankedTensorType.get(shape, f32))
|
||||
|
||||
none = NoneType.get()
|
||||
try:
|
||||
@@ -477,8 +480,7 @@ def testTupleType():
|
||||
@run
|
||||
def testFunctionType():
|
||||
with Context() as ctx:
|
||||
input_types = [IntegerType.get_signless(32),
|
||||
IntegerType.get_signless(16)]
|
||||
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
|
||||
result_types = [IndexType.get()]
|
||||
func = FunctionType.get(input_types, result_types)
|
||||
# CHECK: INPUTS: [Type(i32), Type(i16)]
|
||||
@@ -509,3 +511,91 @@ def testShapedTypeConstants():
|
||||
print(type(ShapedType.get_dynamic_size()))
|
||||
# CHECK: <class 'int'>
|
||||
print(type(ShapedType.get_dynamic_stride_or_offset()))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTypeIDs
|
||||
@run
|
||||
def testTypeIDs():
|
||||
with Context(), Location.unknown():
|
||||
f32 = F32Type.get()
|
||||
|
||||
types = [
|
||||
(IntegerType, IntegerType.get_signless(16)),
|
||||
(IndexType, IndexType.get()),
|
||||
(Float8E4M3FNType, Float8E4M3FNType.get()),
|
||||
(Float8E5M2Type, Float8E5M2Type.get()),
|
||||
(Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
|
||||
(Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
|
||||
(Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
|
||||
(BF16Type, BF16Type.get()),
|
||||
(F16Type, F16Type.get()),
|
||||
(F32Type, F32Type.get()),
|
||||
(F64Type, F64Type.get()),
|
||||
(NoneType, NoneType.get()),
|
||||
(ComplexType, ComplexType.get(f32)),
|
||||
(VectorType, VectorType.get([2, 3], f32)),
|
||||
(RankedTensorType, RankedTensorType.get([2, 3], f32)),
|
||||
(UnrankedTensorType, UnrankedTensorType.get(f32)),
|
||||
(MemRefType, MemRefType.get([2, 3], f32)),
|
||||
(UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
|
||||
(TupleType, TupleType.get_tuple([f32])),
|
||||
(FunctionType, FunctionType.get([], [])),
|
||||
(OpaqueType, OpaqueType.get("tensor", "bob")),
|
||||
]
|
||||
|
||||
# CHECK: IntegerType(i16)
|
||||
# CHECK: IndexType(index)
|
||||
# CHECK: Float8E4M3FNType(f8E4M3FN)
|
||||
# CHECK: Float8E5M2Type(f8E5M2)
|
||||
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
|
||||
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
|
||||
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
|
||||
# CHECK: BF16Type(bf16)
|
||||
# CHECK: F16Type(f16)
|
||||
# CHECK: F32Type(f32)
|
||||
# CHECK: F64Type(f64)
|
||||
# CHECK: NoneType(none)
|
||||
# CHECK: ComplexType(complex<f32>)
|
||||
# CHECK: VectorType(vector<2x3xf32>)
|
||||
# CHECK: RankedTensorType(tensor<2x3xf32>)
|
||||
# CHECK: UnrankedTensorType(tensor<*xf32>)
|
||||
# CHECK: MemRefType(memref<2x3xf32>)
|
||||
# CHECK: UnrankedMemRefType(memref<*xf32, 2>)
|
||||
# CHECK: TupleType(tuple<f32>)
|
||||
# CHECK: FunctionType(() -> ())
|
||||
# CHECK: OpaqueType(!tensor.bob)
|
||||
for _, t in types:
|
||||
print(repr(t))
|
||||
|
||||
# Test getTypeIdFunction agrees with
|
||||
# mlirTypeGetTypeID(self) for an instance.
|
||||
# CHECK: all equal
|
||||
for t1, t2 in types:
|
||||
tid1, tid2 = t1.static_typeid, Type(t2).typeid
|
||||
assert tid1 == tid2 and hash(tid1) == hash(
|
||||
tid2), f"expected hash and value equality {t1} {t2}"
|
||||
else:
|
||||
print("all equal")
|
||||
|
||||
# Test that storing PyTypeID in python dicts
|
||||
# works as expected.
|
||||
typeid_dict = dict(types)
|
||||
assert len(typeid_dict)
|
||||
|
||||
# CHECK: all equal
|
||||
for t1, t2 in typeid_dict.items():
|
||||
assert t1.static_typeid == t2.typeid and hash(
|
||||
t1.static_typeid) == hash(
|
||||
t2.typeid), f"expected hash and value equality {t1} {t2}"
|
||||
else:
|
||||
print("all equal")
|
||||
|
||||
# CHECK: ShapedType has no typeid.
|
||||
try:
|
||||
print(ShapedType.static_typeid)
|
||||
except AttributeError as e:
|
||||
print(e)
|
||||
|
||||
vector_type = Type.parse("vector<2x3xf32>")
|
||||
# CHECK: True
|
||||
print(ShapedType(vector_type).typeid == vector_type.typeid)
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PythonTestCAPI.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
@@ -40,6 +41,9 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
|
||||
return cls(mlirPythonTestTestTypeGet(ctx));
|
||||
},
|
||||
py::arg("cls"), py::arg("context") = py::none());
|
||||
mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("RankedTensorType"));
|
||||
mlir_value_subclass(m, "TestTensorValue",
|
||||
mlirTypeIsAPythonTestTestTensorValue)
|
||||
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
|
||||
|
||||
Reference in New Issue
Block a user