mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 06:40:01 +08:00
[mlir] python enum bindings generator
Add an ODS (tablegen) backend to generate Python enum classes and attribute builders for enum attributes defined in ODS. This will allow us to keep the enum attribute definitions in sync between C++ and Python, as opposed to handwritten enum classes in Python that may end up using mismatching values. This also makes autogenerated bindings more convenient even in absence of mixins. Use this backend for the transform dialect failure propagation mode enum attribute as demonstration. Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156553
This commit is contained in:
@@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings(
|
||||
_mlir_libs/_mlir/dialects/transform/__init__.pyi
|
||||
DIALECT_NAME transform)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
|
||||
mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
|
||||
add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.transform.enum_gen
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
|
||||
SOURCES "dialects/_transform_enum_gen.py")
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
||||
@@ -15,68 +15,66 @@ from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
class CastOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type, _get_op_result_or_value(target), loc=loc, ip=ip
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||
|
||||
|
||||
class ApplyPatternsOp:
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operands = []
|
||||
operands.append(_get_op_result_or_value(target))
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
attributes={},
|
||||
results=[],
|
||||
operands=operands,
|
||||
successors=None,
|
||||
regions=None,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operands = []
|
||||
operands.append(_get_op_result_or_value(target))
|
||||
super().__init__(
|
||||
self.build_generic(attributes={},
|
||||
results=[],
|
||||
operands=operands,
|
||||
successors=None,
|
||||
regions=None,
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def patterns(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
@property
|
||||
def patterns(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
class testGetParentOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
isolated_from_above: bool = False,
|
||||
op_name: Optional[str] = None,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
isolated_from_above=isolated_from_above,
|
||||
op_name=op_name,
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
isolated_from_above: bool = False,
|
||||
op_name: Optional[str] = None,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
isolated_from_above=isolated_from_above,
|
||||
op_name=op_name,
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MergeHandlesOp:
|
||||
@@ -130,12 +128,6 @@ class SequenceOp:
|
||||
else None
|
||||
)
|
||||
root_type = root.type if not isinstance(target, Type) else target
|
||||
if not isinstance(failure_propagation_mode, Attribute):
|
||||
failure_propagation_mode_attr = IntegerAttr.get(
|
||||
IntegerType.get_signless(32), failure_propagation_mode._as_int()
|
||||
)
|
||||
else:
|
||||
failure_propagation_mode_attr = failure_propagation_mode
|
||||
|
||||
if extra_bindings is None:
|
||||
extra_bindings = []
|
||||
@@ -152,7 +144,7 @@ class SequenceOp:
|
||||
|
||||
super().__init__(
|
||||
results_=results,
|
||||
failure_propagation_mode=failure_propagation_mode_attr,
|
||||
failure_propagation_mode=failure_propagation_mode,
|
||||
root=root,
|
||||
extra_bindings=extra_bindings,
|
||||
)
|
||||
|
||||
@@ -2,22 +2,6 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FailurePropagationMode(Enum):
|
||||
"""Propagation mode for silenceable errors."""
|
||||
|
||||
PROPAGATE = 1
|
||||
SUPPRESS = 2
|
||||
|
||||
def _as_int(self):
|
||||
if self is FailurePropagationMode.PROPAGATE:
|
||||
return 1
|
||||
|
||||
assert self is FailurePropagationMode.SUPPRESS
|
||||
return 2
|
||||
|
||||
|
||||
from .._transform_enum_gen import *
|
||||
from .._transform_ops_gen import *
|
||||
from ..._mlir_libs._mlirDialectsTransform import *
|
||||
|
||||
57
mlir/test/mlir-tblgen/enums-python-bindings.td
Normal file
57
mlir/test/mlir-tblgen/enums-python-bindings.td
Normal file
@@ -0,0 +1,57 @@
|
||||
// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
// CHECK: from enum import Enum
|
||||
// CHECK: from ._ods_common import _cext as _ods_cext
|
||||
// CHECK: _ods_ir = _ods_cext.ir
|
||||
|
||||
def One : I32EnumAttrCase<"CaseOne", 1, "one">;
|
||||
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
|
||||
|
||||
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
|
||||
// CHECK: def _register_attribute_builder(kind):
|
||||
// CHECK: def decorator_builder(func):
|
||||
// CHECK: _ods_ir.AttrBuilder.insert(kind, func)
|
||||
// CHECK: return func
|
||||
// CHECK: return decorator_builder
|
||||
|
||||
// CHECK-LABEL: class MyEnum(Enum):
|
||||
// CHECK: """An example 32-bit enum"""
|
||||
|
||||
// CHECK: CASE_ONE = 1
|
||||
// CHECK: CASE_TWO = 2
|
||||
|
||||
// CHECK: def _as_int(self):
|
||||
// CHECK: if self is MyEnum.CASE_ONE:
|
||||
// CHECK: return 1
|
||||
// CHECK: if self is MyEnum.CASE_TWO:
|
||||
// CHECK: return 2
|
||||
// CHECK: assert False, "Unknown MyEnum enum entry."
|
||||
|
||||
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
|
||||
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
|
||||
|
||||
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
|
||||
// CHECK: @_register_attribute_builder("MyEnum")
|
||||
// CHECK: def _my_enum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
|
||||
|
||||
// CHECK-LABEL: class MyEnum64(Enum):
|
||||
// CHECK: """An example 64-bit enum"""
|
||||
|
||||
// CHECK: CASE_ONE64 = 1
|
||||
// CHECK: CASE_TWO64 = 2
|
||||
|
||||
// CHECK: def _as_int(self):
|
||||
// CHECK: if self is MyEnum64.CASE_ONE64:
|
||||
// CHECK: return 1
|
||||
// CHECK: if self is MyEnum64.CASE_TWO64:
|
||||
// CHECK: return 2
|
||||
// CHECK: assert False, "Unknown MyEnum64 enum entry."
|
||||
|
||||
// CHECK: @_register_attribute_builder("MyEnum64")
|
||||
// CHECK: def _my_enum64(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())
|
||||
@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
|
||||
DialectGen.cpp
|
||||
DirectiveCommonGen.cpp
|
||||
EnumsGen.cpp
|
||||
EnumPythonBindingGen.cpp
|
||||
FormatGen.cpp
|
||||
LLVMIRConversionGen.cpp
|
||||
LLVMIRIntrinsicGen.cpp
|
||||
|
||||
130
mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
Normal file
130
mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
Normal file
@@ -0,0 +1,130 @@
|
||||
//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
|
||||
// generate the corresponding Python binding classes.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
/// File header and includes.
|
||||
constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import Enum
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
def _register_attribute_builder(kind):
|
||||
def decorator_builder(func):
|
||||
_ods_ir.AttrBuilder.insert(kind, func)
|
||||
return func
|
||||
|
||||
return decorator_builder
|
||||
|
||||
)Py";
|
||||
|
||||
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
|
||||
static std::string makePythonEnumCaseName(StringRef name) {
|
||||
return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
|
||||
}
|
||||
|
||||
/// Emits the Python class for the given enum.
|
||||
static void emitEnumClass(StringRef enumName, StringRef description,
|
||||
ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
|
||||
os << llvm::formatv("class {0}(Enum):\n", enumName);
|
||||
if (!description.empty())
|
||||
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description);
|
||||
os << "\n";
|
||||
|
||||
for (const EnumAttrCase &enumCase : cases) {
|
||||
os << llvm::formatv(" {0} = {1}\n",
|
||||
makePythonEnumCaseName(enumCase.getSymbol()),
|
||||
enumCase.getValue());
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
os << llvm::formatv(" def _as_int(self):\n");
|
||||
for (const EnumAttrCase &enumCase : cases) {
|
||||
os << llvm::formatv(" if self is {0}.{1}:\n", enumName,
|
||||
makePythonEnumCaseName(enumCase.getSymbol()));
|
||||
os << llvm::formatv(" return {0}\n", enumCase.getValue());
|
||||
}
|
||||
os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n",
|
||||
enumName);
|
||||
}
|
||||
|
||||
/// Attempts to extract the bitwidth B from string "uintB_t" describing the
|
||||
/// type. This bitwidth information is not readily available in ODS. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
|
||||
if (!uintType.consume_front("uint"))
|
||||
return true;
|
||||
if (!uintType.consume_back("_t"))
|
||||
return true;
|
||||
return uintType.getAsInteger(/*Radix=*/10, bitwidth);
|
||||
}
|
||||
|
||||
/// Emits an attribute builder for the given enum attribute to support automatic
|
||||
/// conversion between enum values and attributes in Python. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
||||
int64_t bitwidth;
|
||||
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
|
||||
llvm::errs() << "failed to identify bitwidth of "
|
||||
<< enumAttr.getUnderlyingType();
|
||||
return true;
|
||||
}
|
||||
|
||||
os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
|
||||
enumAttr.getAttrDefName());
|
||||
os << llvm::formatv(
|
||||
"def _{0}(x, context):\n",
|
||||
llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
|
||||
os << llvm::formatv(
|
||||
" return "
|
||||
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
||||
"context=context), x._as_int())\n\n",
|
||||
bitwidth);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Emits Python bindings for all enums in the record keeper. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
os << fileHeader;
|
||||
std::vector<llvm::Record *> defs =
|
||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
|
||||
for (const llvm::Record *def : defs) {
|
||||
EnumAttr enumAttr(*def);
|
||||
if (enumAttr.isBitEnum()) {
|
||||
llvm::errs() << "bit enums not supported\n";
|
||||
return true;
|
||||
}
|
||||
emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
|
||||
enumAttr.getAllCases(), os);
|
||||
emitAttributeBuilder(enumAttr, os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Registers the enum utility generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genPythonEnumBindings("gen-python-enum-bindings",
|
||||
"Generate Python bindings for enum attributes",
|
||||
&emitPythonEnums);
|
||||
@@ -732,6 +732,25 @@ filegroup(
|
||||
# Transform dialect and extensions.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "TransformEnumPyGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
["-gen-python-enum-bindings"],
|
||||
"mlir/dialects/_transform_enum_gen.py",
|
||||
),
|
||||
],
|
||||
tblgen = "//mlir:mlir-tblgen",
|
||||
td_file = "mlir/dialects/TransformOps.td",
|
||||
deps = [
|
||||
"//mlir:CallInterfacesTdFiles",
|
||||
"//mlir:FunctionInterfacesTdFiles",
|
||||
"//mlir:OpBaseTdFiles",
|
||||
"//mlir:TransformDialectTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "TransformOpsPyGen",
|
||||
tbl_outs = [
|
||||
@@ -898,6 +917,7 @@ filegroup(
|
||||
":MemRefTransformOpsPyGen",
|
||||
":PDLTransformOpsPyGen",
|
||||
":StructuredTransformOpsPyGen",
|
||||
":TransformEnumPyGen",
|
||||
":TransformOpsPyGen",
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user