mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 23:25:14 +08:00
[mlir][python] Expose fp8 types with pybind.
Expose fp8 types with pybind. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D140746
This commit is contained in:
@@ -102,6 +102,42 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - Float8E4M3FNType.
|
||||
class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
|
||||
static constexpr const char *pyClassName = "Float8E4M3FNType";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](DefaultingPyMlirContext context) {
|
||||
MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
|
||||
return PyFloat8E4M3FNType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - Float8M5E2Type.
|
||||
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
|
||||
static constexpr const char *pyClassName = "Float8E5M2Type";
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](DefaultingPyMlirContext context) {
|
||||
MlirType t = mlirFloat8E5M2TypeGet(context->get());
|
||||
return PyFloat8E5M2Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e5m2 type.");
|
||||
}
|
||||
};
|
||||
|
||||
/// Floating Point Type subclass - BF16Type.
|
||||
class PyBF16Type : public PyConcreteType<PyBF16Type> {
|
||||
public:
|
||||
@@ -663,6 +699,8 @@ public:
|
||||
void mlir::python::populateIRTypes(py::module &m) {
|
||||
PyIntegerType::bind(m);
|
||||
PyIndexType::bind(m);
|
||||
PyFloat8E4M3FNType::bind(m);
|
||||
PyFloat8E5M2Type::bind(m);
|
||||
PyBF16Type::bind(m);
|
||||
PyF16Type::bind(m);
|
||||
PyF32Type::bind(m);
|
||||
|
||||
@@ -50,6 +50,8 @@ __all__ = [
|
||||
"DiagnosticHandler",
|
||||
"DiagnosticSeverity",
|
||||
"DictAttr",
|
||||
"Float8E4M3FNType",
|
||||
"Float8E5M2Type",
|
||||
"F16Type",
|
||||
"F32Type",
|
||||
"F64Type",
|
||||
@@ -577,6 +579,20 @@ class DictAttr(Attribute):
|
||||
@property
|
||||
def type(self) -> Type: ...
|
||||
|
||||
class Float8E4M3FNType(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
@staticmethod
|
||||
def get(*args, **kwargs) -> Float8E4M3FNType: ...
|
||||
@staticmethod
|
||||
def isinstance(arg: Any) -> bool: ...
|
||||
|
||||
class Float8E5M2Type(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
@staticmethod
|
||||
def get(*args, **kwargs) -> Float8E5M2Type: ...
|
||||
@staticmethod
|
||||
def isinstance(arg: Any) -> bool: ...
|
||||
|
||||
# TODO: Auto-generated. Audit and fix.
|
||||
class F16Type(Type):
|
||||
def __init__(self, cast_from_type: Type) -> None: ...
|
||||
|
||||
@@ -193,6 +193,10 @@ def testIndexType():
|
||||
@run
|
||||
def testFloatType():
|
||||
with Context():
|
||||
# CHECK: float: f8E4M3FN
|
||||
print("float:", Float8E4M3FNType.get())
|
||||
# CHECK: float: f8E5M2
|
||||
print("float:", Float8E5M2Type.get())
|
||||
# CHECK: float: bf16
|
||||
print("float:", BF16Type.get())
|
||||
# CHECK: float: f16
|
||||
|
||||
Reference in New Issue
Block a user