mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 09:31:59 +08:00
[mlir][python] Provide more convenient constructors for std.CallOp
The new constructor relies on type-based dynamic dispatch and allows one to construct call operations given an object representing a FuncOp or its name as a string, as opposed to requiring an explicitly constructed attribute. Depends On D110947 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110948
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import inspect
|
||||
|
||||
@@ -82,8 +82,8 @@ class FuncOp:
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.attributes["sym_name"]
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
@@ -104,11 +104,15 @@ class FuncOp:
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
|
||||
@@ -69,3 +69,73 @@ class ConstantOp:
|
||||
return FloatAttr(self.value).value
|
||||
else:
|
||||
raise ValueError("only integer and float constants have literal values")
|
||||
|
||||
|
||||
class CallOp:
|
||||
"""Specialization for the call op class."""
|
||||
|
||||
def __init__(self,
|
||||
calleeOrResults: Union[FuncOp, List[Type]],
|
||||
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
|
||||
arguments: Optional[List] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
"""Creates an call operation.
|
||||
|
||||
The constructor accepts three different forms:
|
||||
|
||||
1. A function op to be called followed by a list of arguments.
|
||||
2. A list of result types, followed by the name of the function to be
|
||||
called as string, following by a list of arguments.
|
||||
3. A list of result types, followed by the name of the function to be
|
||||
called as symbol reference attribute, followed by a list of arguments.
|
||||
|
||||
For example
|
||||
|
||||
f = builtin.FuncOp("foo", ...)
|
||||
std.CallOp(f, [args])
|
||||
std.CallOp([result_types], "foo", [args])
|
||||
|
||||
In all cases, the location and insertion point may be specified as keyword
|
||||
arguments if not provided by the surrounding context managers.
|
||||
"""
|
||||
|
||||
# TODO: consider supporting constructor "overloads", e.g., through a custom
|
||||
# or pybind-provided metaclass.
|
||||
if isinstance(calleeOrResults, FuncOp):
|
||||
if not isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function, expected " +
|
||||
"the second argument to be a list of call arguments, " +
|
||||
f"got {type(argumentsOrCallee)}")
|
||||
if arguments is not None:
|
||||
raise ValueError("unexpected third argument when constructing a call" +
|
||||
"to a function")
|
||||
|
||||
super().__init__(
|
||||
calleeOrResults.type.results,
|
||||
FlatSymbolRefAttr.get(
|
||||
calleeOrResults.name.value,
|
||||
context=_get_default_loc_context(loc)),
|
||||
argumentsOrCallee,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
return
|
||||
|
||||
if isinstance(argumentsOrCallee, list):
|
||||
raise ValueError("when constructing a call to a function by name, " +
|
||||
"expected the second argument to be a string or a " +
|
||||
f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
|
||||
|
||||
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
|
||||
super().__init__(
|
||||
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
|
||||
elif isinstance(argumentsOrCallee, str):
|
||||
super().__init__(
|
||||
calleeOrResults,
|
||||
FlatSymbolRefAttr.get(
|
||||
argumentsOrCallee, context=_get_default_loc_context(loc)),
|
||||
arguments,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
@@ -171,7 +171,7 @@ def testFuncArgumentAccess():
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
|
||||
func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
std.ReturnOp(func.arguments)
|
||||
func.arg_attrs = ArrayAttr.get([
|
||||
@@ -186,6 +186,14 @@ def testFuncArgumentAccess():
|
||||
DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
|
||||
])
|
||||
|
||||
other = builtin.FuncOp("other_func", ([f32, f32], []))
|
||||
with InsertionPoint(other.add_entry_block()):
|
||||
std.ReturnOp([])
|
||||
other.arg_attrs = [
|
||||
DictAttr.get({"foo": StringAttr.get("qux")}),
|
||||
DictAttr.get()
|
||||
]
|
||||
|
||||
# CHECK: [{baz, foo = "bar"}, {qux = []}]
|
||||
print(func.arg_attrs)
|
||||
|
||||
@@ -195,7 +203,11 @@ def testFuncArgumentAccess():
|
||||
# CHECK: func @some_func(
|
||||
# CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
|
||||
# CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
|
||||
# CHECK: f64 {res1 = 4.200000e+01 : f32},
|
||||
# CHECK: f64 {res2 = 2.560000e+02 : f64})
|
||||
# CHECK: f32 {res1 = 4.200000e+01 : f32},
|
||||
# CHECK: f32 {res2 = 2.560000e+02 : f64})
|
||||
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
|
||||
#
|
||||
# CHECK: func @other_func(
|
||||
# CHECK: %{{.*}}: f32 {foo = "qux"},
|
||||
# CHECK: %{{.*}}: f32)
|
||||
print(module)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import builtin
|
||||
from mlir.dialects import std
|
||||
|
||||
|
||||
@@ -62,3 +63,27 @@ def testConstantIndexOp():
|
||||
print(c1.literal_value)
|
||||
|
||||
# CHECK: = constant 10 : index
|
||||
|
||||
# CHECK-LABEL: TEST: testFunctionCalls
|
||||
@constructAndPrintInModule
|
||||
def testFunctionCalls():
|
||||
foo = builtin.FuncOp("foo", ([], []))
|
||||
bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
|
||||
qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
|
||||
|
||||
with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
|
||||
std.CallOp(foo, [])
|
||||
std.CallOp([IndexType.get()], "bar", [])
|
||||
std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
|
||||
std.ReturnOp([])
|
||||
|
||||
# CHECK: func @foo()
|
||||
# CHECK: func @bar() -> index
|
||||
# CHECK: func @qux() -> f32
|
||||
# CHECK: func @caller() {
|
||||
# CHECK: call @foo() : () -> ()
|
||||
# CHECK: %0 = call @bar() : () -> index
|
||||
# CHECK: %1 = call @qux() : () -> f32
|
||||
# CHECK: return
|
||||
# CHECK: }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user