From f3dab16dc721feee669947d4cb95e88ab90c78f5 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 17 Nov 2020 14:17:22 +0100 Subject: [PATCH] [mlir] Add a _get_default_loc_context utility to Python bindings This utility function is helpful for dialect-specific builders that need to access the context through location, and the location itself may be either provided as an argument or expected to be recovered from the implicit location stack. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D91623 --- mlir/lib/Bindings/Python/mlir/dialects/__init__.py | 11 +++++++++++ mlir/test/mlir-tblgen/op-python-bindings.td | 8 ++++---- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 6 +++--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py index 0aceff1caf3f..56398ea5b64a 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py @@ -41,3 +41,14 @@ def _equally_sized_accessor(elements, n_variadic, n_preceding_simple, elements_per_group = total_variadic_length // n_variadic start = n_preceding_simple + n_preceding_variadic * elements_per_group return start, elements_per_group + +def _get_default_loc_context(location = None): + """ + Returns a context in which the defaulted location is created. If the location + is None, takes the current location from the stack, raises ValueError if there + is no location on the stack. + """ + if location is None: + # Location.current raises ValueError if there is no current location. + return _cext.ir.Location.current.context + return location.context diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 04d798a42641..5a27cc2f0cd5 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -31,7 +31,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: if variadic2 is not None: operands.append(variadic2) // CHECK: operand_segment_sizes.append(0 if variadic2 is None else 1) // CHECK: attributes["operand_segment_sizes"] = _ir.DenseElementsAttr.get(operand_segment_sizes, - // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: context=_get_default_loc_context(loc)) // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -77,7 +77,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", // CHECK: if variadic2 is not None: results.append(variadic2) // CHECK: result_segment_sizes.append(0 if variadic2 is None else 1) // CHECK: attributes["result_segment_sizes"] = _ir.DenseElementsAttr.get(result_segment_sizes, - // CHECK: context=Location.current.context if loc is None else loc.context) + // CHECK: context=_get_default_loc_context(loc)) // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) @@ -118,7 +118,7 @@ def AttributedOp : TestOp<"attributed_op"> { // CHECK: attributes["i32attr"] = i32attr // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get( - // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: _get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results, @@ -156,7 +156,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(_gen_arg_2) // CHECK: if bool(in_): attributes["in"] = _ir.UnitAttr.get( - // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: _get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ // CHECK: super().__init__(_ir.Operation.create( // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results, diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 2a3ce5500133..e970d305fd8a 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -28,7 +28,7 @@ constexpr const char *fileHeader = R"Py( import array from . import _cext -from . import _segmented_accessor, _equally_sized_accessor +from . import _segmented_accessor, _equally_sized_accessor, _get_default_loc_context _ir = _cext.ir )Py"; @@ -410,7 +410,7 @@ constexpr const char *segmentDeclarationTemplate = /// Template for attaching segment sizes to the attribute list. constexpr const char *segmentAttributeTemplate = R"Py(attributes["{0}_segment_sizes"] = _ir.DenseElementsAttr.get({0}_segment_sizes, - context=Location.current.context if loc is None else loc.context))Py"; + context=_get_default_loc_context(loc)))Py"; /// Template for appending the unit size to the segment sizes. /// {0} is either 'operand' or 'result'; @@ -443,7 +443,7 @@ constexpr const char *initOptionalAttributeTemplate = constexpr const char *initUnitAttributeTemplate = R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get( - _ir.Location.current.context if loc is None else loc.context))Py"; + _get_default_loc_context(loc)))Py"; /// Populates `builderArgs` with the Python-compatible names of builder function /// arguments, first the results, then the intermixed attributes and operands in