[mlir][python] Expose fp8 types with pybind.

Expose fp8 types with pybind.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D140746
This commit is contained in:
Qiao Zhang
2023-01-03 19:06:30 +00:00
committed by Mehdi Amini
parent 2671aa7e84
commit 4d29f6ed6e
3 changed files with 58 additions and 0 deletions

View File

@@ -102,6 +102,42 @@ public:
}
};
/// Floating Point Type subclass - Float8E4M3FNType.
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
static constexpr const char *pyClassName = "Float8E4M3FNType";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
return PyFloat8E4M3FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
}
};
/// Floating Point Type subclass - Float8M5E2Type.
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
static constexpr const char *pyClassName = "Float8E5M2Type";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E5M2TypeGet(context->get());
return PyFloat8E5M2Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e5m2 type.");
}
};
/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
@@ -663,6 +699,8 @@ public:
void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);

View File

@@ -50,6 +50,8 @@ __all__ = [
"DiagnosticHandler",
"DiagnosticSeverity",
"DictAttr",
"Float8E4M3FNType",
"Float8E5M2Type",
"F16Type",
"F32Type",
"F64Type",
@@ -577,6 +579,20 @@ class DictAttr(Attribute):
@property
def type(self) -> Type: ...
class Float8E4M3FNType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E4M3FNType: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...
class Float8E5M2Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E5M2Type: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...
# TODO: Auto-generated. Audit and fix.
class F16Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...

View File

@@ -193,6 +193,10 @@ def testIndexType():
@run
def testFloatType():
with Context():
# CHECK: float: f8E4M3FN
print("float:", Float8E4M3FNType.get())
# CHECK: float: f8E5M2
print("float:", Float8E5M2Type.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16