[mlir][python] Add python bindings for DenseArrayAttr

This patch adds python bindings for the dense array variants.

Fixes #56975

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D131801
This commit is contained in:
Jeff Niu
2022-08-12 15:41:42 -04:00
parent 405ad84793
commit 619fd8c2ab
5 changed files with 385 additions and 0 deletions

View File

@@ -296,6 +296,61 @@ mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs);
/// shaped type and use its sizes to build a multi-dimensional index.
MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
//===----------------------------------------------------------------------===//
// Dense array attribute.
//===----------------------------------------------------------------------===//
/// Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr);
/// Create a dense array attribute with the given elements.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx,
intptr_t size,
int const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx,
intptr_t size,
int8_t const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx,
intptr_t size,
int16_t const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx,
intptr_t size,
int32_t const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx,
intptr_t size,
int64_t const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx,
intptr_t size,
float const *values);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx,
intptr_t size,
double const *values);
/// Get the size of a dense array.
MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr);
/// Get an element of a dense array.
MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr,
intptr_t pos);
//===----------------------------------------------------------------------===//
// Dense elements attribute.
//===----------------------------------------------------------------------===//

View File

@@ -110,6 +110,161 @@ static T pyTryCast(py::handle object) {
}
}
/// A python-wrapped dense array attribute with an element type and a derived
/// implementation class.
template <typename EltTy, typename DerivedT>
class PyDenseArrayAttribute
: public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
public:
static constexpr typename PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
DerivedT::isaFunction;
static constexpr const char *pyClassName = DerivedT::pyClassName;
using PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
/// Iterator over the integer elements of a dense array.
class PyDenseArrayIterator {
public:
PyDenseArrayIterator(PyAttribute attr) : attr(attr) {}
/// Return a copy of the iterator.
PyDenseArrayIterator dunderIter() { return *this; }
/// Return the next element.
EltTy dunderNext() {
// Throw if the index has reached the end.
if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
throw py::stop_iteration();
return DerivedT::getElement(attr.get(), nextIndex++);
}
/// Bind the iterator class.
static void bind(py::module &m) {
py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
py::module_local())
.def("__iter__", &PyDenseArrayIterator::dunderIter)
.def("__next__", &PyDenseArrayIterator::dunderNext);
}
private:
/// The referenced dense array attribute.
PyAttribute attr;
/// The next index to read.
int nextIndex = 0;
};
/// Get the element at the given index.
EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
/// Bind the attribute class.
static void bindDerived(typename PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
// Bind the constructor.
c.def_static(
"get",
[](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
MlirAttribute attr =
DerivedT::getAttribute(ctx->get(), values.size(), values.data());
return PyDenseArrayAttribute<EltTy, DerivedT>(ctx->getRef(), attr);
},
py::arg("values"), py::arg("context") = py::none(),
"Gets a uniqued dense array attribute");
// Bind the array methods.
c.def("__getitem__",
[](PyDenseArrayAttribute<EltTy, DerivedT> &arr, intptr_t i) {
if (i >= mlirDenseArrayGetNumElements(arr))
throw py::index_error("DenseArray index out of range");
return arr.getItem(i);
});
c.def("__len__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
return mlirDenseArrayGetNumElements(arr);
});
c.def("__iter__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
return PyDenseArrayIterator(arr);
});
// Bind a concat.
c.def("__add__", [](PyDenseArrayAttribute<EltTy, DerivedT> &arr,
py::list extras) {
std::vector<EltTy> values;
intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
values.reserve(numOldElements + py::len(extras));
for (intptr_t i = 0; i < numOldElements; ++i)
values.push_back(arr.getItem(i));
for (py::handle attr : extras)
values.push_back(pyTryCast<EltTy>(attr));
MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
values.size(), values.data());
return PyDenseArrayAttribute<EltTy, DerivedT>(arr.getContext(), attr);
});
}
};
/// Instantiate the python dense array classes.
struct PyDenseBoolArrayAttribute
: public PyDenseArrayAttribute<int, PyDenseBoolArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
static constexpr auto getAttribute = mlirDenseBoolArrayGet;
static constexpr auto getElement = mlirDenseBoolArrayGetElement;
static constexpr const char *pyClassName = "DenseBoolArrayAttr";
static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseI8ArrayAttribute
: public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
static constexpr auto getAttribute = mlirDenseI8ArrayGet;
static constexpr auto getElement = mlirDenseI8ArrayGetElement;
static constexpr const char *pyClassName = "DenseI8ArrayAttr";
static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseI16ArrayAttribute
: public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
static constexpr auto getAttribute = mlirDenseI16ArrayGet;
static constexpr auto getElement = mlirDenseI16ArrayGetElement;
static constexpr const char *pyClassName = "DenseI16ArrayAttr";
static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseI32ArrayAttribute
: public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
static constexpr auto getAttribute = mlirDenseI32ArrayGet;
static constexpr auto getElement = mlirDenseI32ArrayGetElement;
static constexpr const char *pyClassName = "DenseI32ArrayAttr";
static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseI64ArrayAttribute
: public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
static constexpr auto getAttribute = mlirDenseI64ArrayGet;
static constexpr auto getElement = mlirDenseI64ArrayGetElement;
static constexpr const char *pyClassName = "DenseI64ArrayAttr";
static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseF32ArrayAttribute
: public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
static constexpr auto getAttribute = mlirDenseF32ArrayGet;
static constexpr auto getElement = mlirDenseF32ArrayGetElement;
static constexpr const char *pyClassName = "DenseF32ArrayAttr";
static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
struct PyDenseF64ArrayAttribute
: public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
static constexpr auto getAttribute = mlirDenseF64ArrayGet;
static constexpr auto getElement = mlirDenseF64ArrayGetElement;
static constexpr const char *pyClassName = "DenseF64ArrayAttr";
static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
using PyDenseArrayAttribute::PyDenseArrayAttribute;
};
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -891,6 +1046,22 @@ public:
void mlir::python::populateIRAttributes(py::module &m) {
PyAffineMapAttribute::bind(m);
PyDenseBoolArrayAttribute::bind(m);
PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI8ArrayAttribute::bind(m);
PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI16ArrayAttribute::bind(m);
PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI32ArrayAttribute::bind(m);
PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseI64ArrayAttribute::bind(m);
PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseF32ArrayAttribute::bind(m);
PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
PyDenseF64ArrayAttribute::bind(m);
PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
PyArrayAttribute::bind(m);
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
PyBoolAttribute::bind(m);

View File

@@ -311,6 +311,106 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
return unwrap(attr).cast<ElementsAttr>().getNumElements();
}
//===----------------------------------------------------------------------===//
// Dense array attribute.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// IsA support.
bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
return unwrap(attr).isa<DenseBoolArrayAttr>();
}
bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseI8ArrayAttr>();
}
bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseI16ArrayAttr>();
}
bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseI32ArrayAttr>();
}
bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseI64ArrayAttr>();
}
bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseF32ArrayAttr>();
}
bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
return unwrap(attr).isa<DenseF64ArrayAttr>();
}
//===----------------------------------------------------------------------===//
// Constructors.
MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
int const *values) {
SmallVector<bool, 4> elements(values, values + size);
return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
}
MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
int8_t const *values) {
return wrap(
DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size)));
}
MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
int16_t const *values) {
return wrap(
DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size)));
}
MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
int32_t const *values) {
return wrap(
DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size)));
}
MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
int64_t const *values) {
return wrap(
DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size)));
}
MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
float const *values) {
return wrap(
DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size)));
}
MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
double const *values) {
return wrap(
DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size)));
}
//===----------------------------------------------------------------------===//
// Accessors.
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
return unwrap(attr).cast<DenseArrayBaseAttr>().size();
}
//===----------------------------------------------------------------------===//
// Indexed accessors.
bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseBoolArrayAttr>()[pos];
}
int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseI8ArrayAttr>()[pos];
}
int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseI16ArrayAttr>()[pos];
}
int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseI32ArrayAttr>()[pos];
}
int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseI64ArrayAttr>()[pos];
}
float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseF32ArrayAttr>()[pos];
}
double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseF64ArrayAttr>()[pos];
}
//===----------------------------------------------------------------------===//
// Dense elements attribute.
//===----------------------------------------------------------------------===//

View File

@@ -1186,6 +1186,40 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirAttributeDump(sparseAttr);
// CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>
MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, 2, bools);
MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, 2, ints8);
MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, 2, ints16);
MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, 2, ints32);
MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, 2, ints64);
MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, 2, floats);
MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, 2, doubles);
if (!mlirAttributeIsADenseBoolArray(boolArray) ||
!mlirAttributeIsADenseI8Array(int8Array) ||
!mlirAttributeIsADenseI16Array(int16Array) ||
!mlirAttributeIsADenseI32Array(int32Array) ||
!mlirAttributeIsADenseI64Array(int64Array) ||
!mlirAttributeIsADenseF32Array(floatArray) ||
!mlirAttributeIsADenseF64Array(doubleArray))
return 19;
if (mlirDenseArrayGetNumElements(boolArray) != 2 ||
mlirDenseArrayGetNumElements(int8Array) != 2 ||
mlirDenseArrayGetNumElements(int16Array) != 2 ||
mlirDenseArrayGetNumElements(int32Array) != 2 ||
mlirDenseArrayGetNumElements(int64Array) != 2 ||
mlirDenseArrayGetNumElements(floatArray) != 2 ||
mlirDenseArrayGetNumElements(doubleArray) != 2)
return 20;
if (mlirDenseBoolArrayGetElement(boolArray, 1) != 1 ||
mlirDenseI8ArrayGetElement(int8Array, 1) != 1 ||
mlirDenseI16ArrayGetElement(int16Array, 1) != 1 ||
mlirDenseI32ArrayGetElement(int32Array, 1) != 1 ||
mlirDenseI64ArrayGetElement(int64Array, 1) != 1 ||
fabsf(mlirDenseF32ArrayGetElement(floatArray, 1) - 1.0f) > 1E-6f ||
fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
return 21;
return 0;
}

View File

@@ -1,8 +1,10 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -319,6 +321,29 @@ def testDenseIntAttr():
print(ShapedType(a.type).element_type)
@run
def testDenseArrayGetItem():
def print_item(AttrClass, attr_asm):
attr = AttrClass(Attribute.parse(attr_asm))
print(f"{len(attr)}: {attr[0]}, {attr[1]}")
with Context():
# CHECK: 2: 0, 1
print_item(DenseBoolArrayAttr, "array<i1: false, true>")
# CHECK: 2: 2, 3
print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
# CHECK: 2: 4, 5
print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
# CHECK: 2: 6, 7
print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
# CHECK: 2: 8, 9
print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
# CHECK: 2: 1.{{0+}}, 2.{{0+}}
print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
# CHECK: 2: 3.{{0+}}, 4.{{0+}}
print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():