[mlir] python enum bindings generator

Add an ODS (tablegen) backend to generate Python enum classes and
attribute builders for enum attributes defined in ODS. This will allow
us to keep the enum attribute definitions in sync between C++ and
Python, as opposed to handwritten enum classes in Python that may end up
using mismatching values. This also makes autogenerated bindings more
convenient even in absence of mixins.

Use this backend for the transform dialect failure propagation mode enum
attribute as demonstration.

Reviewed By: ingomueller-net

Differential Revision: https://reviews.llvm.org/D156553
This commit is contained in:
Alex Zinenko
2023-07-28 16:03:33 +00:00
parent 235390d930
commit 1f8618f88c
7 changed files with 272 additions and 79 deletions

View File

@@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings(
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform)
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
declare_mlir_python_sources(
MLIRPythonSources.Dialects.transform.enum_gen
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
SOURCES "dialects/_transform_enum_gen.py")
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

View File

@@ -15,68 +15,66 @@ from typing import Optional, Sequence, Union
class CastOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None,
):
super().__init__(
result_type, _get_op_result_or_value(target), loc=loc, ip=ip
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None,
):
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class ApplyPatternsOp:
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
loc=None,
ip=None,
):
operands = []
operands.append(_get_op_result_or_value(target))
super().__init__(
self.build_generic(
attributes={},
results=[],
operands=operands,
successors=None,
regions=None,
loc=loc,
ip=ip,
)
)
self.regions[0].blocks.append()
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
loc=None,
ip=None,
):
operands = []
operands.append(_get_op_result_or_value(target))
super().__init__(
self.build_generic(attributes={},
results=[],
operands=operands,
successors=None,
regions=None,
loc=loc,
ip=ip))
self.regions[0].blocks.append()
@property
def patterns(self) -> Block:
return self.regions[0].blocks[0]
@property
def patterns(self) -> Block:
return self.regions[0].blocks[0]
class testGetParentOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
class MergeHandlesOp:
@@ -130,12 +128,6 @@ class SequenceOp:
else None
)
root_type = root.type if not isinstance(target, Type) else target
if not isinstance(failure_propagation_mode, Attribute):
failure_propagation_mode_attr = IntegerAttr.get(
IntegerType.get_signless(32), failure_propagation_mode._as_int()
)
else:
failure_propagation_mode_attr = failure_propagation_mode
if extra_bindings is None:
extra_bindings = []
@@ -152,7 +144,7 @@ class SequenceOp:
super().__init__(
results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
failure_propagation_mode=failure_propagation_mode,
root=root,
extra_bindings=extra_bindings,
)

View File

@@ -2,22 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from enum import Enum
class FailurePropagationMode(Enum):
"""Propagation mode for silenceable errors."""
PROPAGATE = 1
SUPPRESS = 2
def _as_int(self):
if self is FailurePropagationMode.PROPAGATE:
return 1
assert self is FailurePropagationMode.SUPPRESS
return 2
from .._transform_enum_gen import *
from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *

View File

@@ -0,0 +1,57 @@
// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s
include "mlir/IR/EnumAttr.td"
// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
// CHECK: from enum import Enum
// CHECK: from ._ods_common import _cext as _ods_cext
// CHECK: _ods_ir = _ods_cext.ir
def One : I32EnumAttrCase<"CaseOne", 1, "one">;
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
// CHECK: def _register_attribute_builder(kind):
// CHECK: def decorator_builder(func):
// CHECK: _ods_ir.AttrBuilder.insert(kind, func)
// CHECK: return func
// CHECK: return decorator_builder
// CHECK-LABEL: class MyEnum(Enum):
// CHECK: """An example 32-bit enum"""
// CHECK: CASE_ONE = 1
// CHECK: CASE_TWO = 2
// CHECK: def _as_int(self):
// CHECK: if self is MyEnum.CASE_ONE:
// CHECK: return 1
// CHECK: if self is MyEnum.CASE_TWO:
// CHECK: return 2
// CHECK: assert False, "Unknown MyEnum enum entry."
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
// CHECK: @_register_attribute_builder("MyEnum")
// CHECK: def _my_enum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
// CHECK-LABEL: class MyEnum64(Enum):
// CHECK: """An example 64-bit enum"""
// CHECK: CASE_ONE64 = 1
// CHECK: CASE_TWO64 = 2
// CHECK: def _as_int(self):
// CHECK: if self is MyEnum64.CASE_ONE64:
// CHECK: return 1
// CHECK: if self is MyEnum64.CASE_TWO64:
// CHECK: return 2
// CHECK: assert False, "Unknown MyEnum64 enum entry."
// CHECK: @_register_attribute_builder("MyEnum64")
// CHECK: def _my_enum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())

View File

@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
EnumPythonBindingGen.cpp
FormatGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp

View File

@@ -0,0 +1,130 @@
//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
// generate the corresponding Python binding classes.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
/// File header and includes.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import Enum
from ._ods_common import _cext as _ods_cext
_ods_ir = _ods_cext.ir
# Convenience decorator for registering user-friendly Attribute builders.
def _register_attribute_builder(kind):
def decorator_builder(func):
_ods_ir.AttrBuilder.insert(kind, func)
return func
return decorator_builder
)Py";
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
static std::string makePythonEnumCaseName(StringRef name) {
return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
}
/// Emits the Python class for the given enum.
static void emitEnumClass(StringRef enumName, StringRef description,
ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
os << llvm::formatv("class {0}(Enum):\n", enumName);
if (!description.empty())
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description);
os << "\n";
for (const EnumAttrCase &enumCase : cases) {
os << llvm::formatv(" {0} = {1}\n",
makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue());
}
os << "\n";
os << llvm::formatv(" def _as_int(self):\n");
for (const EnumAttrCase &enumCase : cases) {
os << llvm::formatv(" if self is {0}.{1}:\n", enumName,
makePythonEnumCaseName(enumCase.getSymbol()));
os << llvm::formatv(" return {0}\n", enumCase.getValue());
}
os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n",
enumName);
}
/// Attempts to extract the bitwidth B from string "uintB_t" describing the
/// type. This bitwidth information is not readily available in ODS. Returns
/// `false` on success, `true` on failure.
static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
if (!uintType.consume_front("uint"))
return true;
if (!uintType.consume_back("_t"))
return true;
return uintType.getAsInteger(/*Radix=*/10, bitwidth);
}
/// Emits an attribute builder for the given enum attribute to support automatic
/// conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
int64_t bitwidth;
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
llvm::errs() << "failed to identify bitwidth of "
<< enumAttr.getUnderlyingType();
return true;
}
os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName());
os << llvm::formatv(
"def _{0}(x, context):\n",
llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
os << llvm::formatv(
" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), x._as_int())\n\n",
bitwidth);
return false;
}
/// Emits Python bindings for all enums in the record keeper. Returns
/// `false` on success, `true` on failure.
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
os << fileHeader;
std::vector<llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
for (const llvm::Record *def : defs) {
EnumAttr enumAttr(*def);
if (enumAttr.isBitEnum()) {
llvm::errs() << "bit enums not supported\n";
return true;
}
emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
enumAttr.getAllCases(), os);
emitAttributeBuilder(enumAttr, os);
}
return false;
}
// Registers the enum utility generator to mlir-tblgen.
static mlir::GenRegistration
genPythonEnumBindings("gen-python-enum-bindings",
"Generate Python bindings for enum attributes",
&emitPythonEnums);

View File

@@ -732,6 +732,25 @@ filegroup(
# Transform dialect and extensions.
##---------------------------------------------------------------------------##
gentbl_filegroup(
name = "TransformEnumPyGen",
tbl_outs = [
(
["-gen-python-enum-bindings"],
"mlir/dialects/_transform_enum_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/TransformOps.td",
deps = [
"//mlir:CallInterfacesTdFiles",
"//mlir:FunctionInterfacesTdFiles",
"//mlir:OpBaseTdFiles",
"//mlir:TransformDialectTdFiles",
],
)
gentbl_filegroup(
name = "TransformOpsPyGen",
tbl_outs = [
@@ -898,6 +917,7 @@ filegroup(
":MemRefTransformOpsPyGen",
":PDLTransformOpsPyGen",
":StructuredTransformOpsPyGen",
":TransformEnumPyGen",
":TransformOpsPyGen",
],
)