[mlir][python] Create all missing attribute builders.

This patch adds attribute builders for all buildable attributes from the
builtin dialect that did not previously have any. These builders can be
used to construct attributes of a particular type identified by a string
from a Python argument without knowing the details of how to pass that
Python argument to the attribute constructor. This is used, for example,
in the generated code of the Python bindings of ops.

The list of "all" attributes was produced with:

(
  grep -h "ods_ir.AttrBuilder.get" $(find ../build/ -name "*_ops_gen.py") \
    | cut -f2 -d"'"
  git grep -ho "^def [a-zA-Z0-9_]*" -- include/mlir/IR/CommonAttrConstraints.td \
    | cut -f2 -d" "
) | sort -u

Then, I only retained those that had an occurence in
`mlir/include/mlir/IR`. In particular, this drops many dialect-specific
attributes; registering those builders is something that those dialects
should do. Finally, I removed those attrbiutes that had a match in
`mlir/python/mlir/ir.py` already and implemented the remaining ones. The
only ones that still miss a builder now are the following:

* Represent more than one possible attribute type:
  - `Any.*Attr` (9x)
  - `IntNonNegative`
  - `IntPositive`
  - `IsNullAttr`
  - `ElementsAttr`
* I am not sure what "constant attributes" are:
  - `ConstBoolAttrFalse`
  - `ConstBoolAttrTrue`
  - `ConstUnitAttr`
* `Location` not exposed by Python bindings:
  - `LocationArrayAttr`
  - `LocationAttr`
* `get` function not implemented in Python bindings:
  - `StringElementsAttr`

This patch also fixes a compilation problem with
`I64SmallVectorArrayAttr`.

Reviewed By: makslevental, rkayaith

Differential Revision: https://reviews.llvm.org/D159403
This commit is contained in:
Ingo Müller
2023-09-01 09:11:35 +00:00
parent d26c78b2ad
commit ca23c933bd
4 changed files with 280 additions and 31 deletions

View File

@@ -16,16 +16,36 @@ def register_attribute_builder(kind, replace=False):
return decorator_builder
@register_attribute_builder("AffineMapAttr")
def _affineMapAttr(x, context):
return AffineMapAttr.get(x)
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
@register_attribute_builder("DictionaryAttr")
def _dictAttr(x, context):
return DictAttr.get(x, context=context)
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)
@register_attribute_builder("I1Attr")
def _i1Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
@register_attribute_builder("I8Attr")
def _i8Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
@register_attribute_builder("I16Attr")
def _i16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@@ -41,6 +61,16 @@ def _i64Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
@register_attribute_builder("SI1Attr")
def _si1Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
@register_attribute_builder("SI8Attr")
def _i8Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
@register_attribute_builder("SI16Attr")
def _si16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@@ -51,6 +81,36 @@ def _si32Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
@register_attribute_builder("SI64Attr")
def _si64Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
@register_attribute_builder("UI1Attr")
def _ui1Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
@register_attribute_builder("UI8Attr")
def _i8Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
@register_attribute_builder("UI16Attr")
def _ui16Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
@register_attribute_builder("UI32Attr")
def _ui32Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
@register_attribute_builder("UI64Attr")
def _ui64Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
@register_attribute_builder("F32Attr")
def _f32Attr(x, context):
return FloatAttr.get_f32(x, context=context)
@@ -84,11 +144,39 @@ def _flatSymbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
@register_attribute_builder("UnitAttr")
def _unitAttr(x, context):
if x:
return UnitAttr.get(context=context)
else:
return None
@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
return ArrayAttr.get(x, context=context)
@register_attribute_builder("AffineMapArrayAttr")
def _affineMapArrayAttr(x, context):
return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
@register_attribute_builder("BoolArrayAttr")
def _boolArrayAttr(x, context):
return ArrayAttr.get([_boolAttr(v, context) for v in x])
@register_attribute_builder("DictArrayAttr")
def _dictArrayAttr(x, context):
return ArrayAttr.get([_dictAttr(v, context) for v in x])
@register_attribute_builder("FlatSymbolRefArrayAttr")
def _flatSymbolRefArrayAttr(x, context):
return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
@register_attribute_builder("I32ArrayAttr")
def _i32ArrayAttr(x, context):
return ArrayAttr.get([_i32Attr(v, context) for v in x])
@@ -99,6 +187,16 @@ def _i64ArrayAttr(x, context):
return ArrayAttr.get([_i64Attr(v, context) for v in x])
@register_attribute_builder("I64SmallVectorArrayAttr")
def _i64SmallVectorArrayAttr(x, context):
return _i64ArrayAttr(x, context=context)
@register_attribute_builder("IndexListArrayAttr")
def _indexListArrayAttr(x, context):
return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
@register_attribute_builder("F32ArrayAttr")
def _f32ArrayAttr(x, context):
return ArrayAttr.get([_f32Attr(v, context) for v in x])
@@ -109,6 +207,41 @@ def _f64ArrayAttr(x, context):
return ArrayAttr.get([_f64Attr(v, context) for v in x])
@register_attribute_builder("StrArrayAttr")
def _strArrayAttr(x, context):
return ArrayAttr.get([_stringAttr(v, context) for v in x])
@register_attribute_builder("SymbolRefArrayAttr")
def _symbolRefArrayAttr(x, context):
return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
@register_attribute_builder("DenseF32ArrayAttr")
def _denseF32ArrayAttr(x, context):
return DenseF32ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseF64ArrayAttr")
def _denseF64ArrayAttr(x, context):
return DenseF64ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseI8ArrayAttr")
def _denseI8ArrayAttr(x, context):
return DenseI8ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseI16ArrayAttr")
def _denseI16ArrayAttr(x, context):
return DenseI16ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseI32ArrayAttr")
def _denseI32ArrayAttr(x, context):
return DenseI32ArrayAttr.get(x, context=context)
@register_attribute_builder("DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
return DenseI64ArrayAttr.get(x, context=context)
@@ -132,6 +265,30 @@ def _typeArrayAttr(x, context):
try:
import numpy as np
@register_attribute_builder("F64ElementsAttr")
def _f64ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
type=F64Type.get(context=context),
context=context,
)
@register_attribute_builder("I32ElementsAttr")
def _i32ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int32),
type=IntegerType.get_signed(32, context=context),
context=context,
)
@register_attribute_builder("I64ElementsAttr")
def _i64ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
type=IntegerType.get_signed(64, context=context),
context=context,
)
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(