[MLIR][SCF] Add dedicated Python bindings for ForallOp (#149416)
This patch specializes the Python bindings for ForallOp and InParallelOp, mirroring the existing specialization for ForOp. The specialized bindings create the ops' regions and blocks properly and expose additional helpers.
commit fef4238288
parent 68fd102598
committed by GitHub
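
Before the diff, a minimal standalone sketch of what the new bindings enable, modeled on the test added in this commit. It assumes an MLIR build whose Python package is importable as `mlir` and that includes this patch; the function name `forall_example` and the surrounding module scaffolding are illustrative only, not part of the commit.

from mlir.ir import (
    Context,
    F32Type,
    InsertionPoint,
    Location,
    Module,
    RankedTensorType,
)
from mlir.dialects import arith, func, scf

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        tensor_type = RankedTensorType.get([4, 8], f32)

        # Build a 2-D scf.forall over a shared tensor output; integer
        # bounds and steps are attached as static attributes.
        @func.FuncOp.from_py_func(tensor_type)
        def forall_example(tensor):
            loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
            with InsertionPoint(loop.body):
                i, j = loop.induction_variables
                arith.addi(i, j)
                loop.terminator()  # ensure the scf.forall.in_parallel terminator

    print(module)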
mlir/python/mlir/dialects/scf.py

@@ -17,7 +17,7 @@ try:
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e

-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union


 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,123 @@ class ForOp(ForOp):
         return self.body.arguments[1:]


+def _dispatch_index_op_fold_results(
+    ofrs: Sequence[Union[Operation, OpView, Value, int]],
+) -> Tuple[List[Value], List[int]]:
+    """`mlir::dispatchIndexOpFoldResults`"""
+    dynamic_vals = []
+    static_vals = []
+    for ofr in ofrs:
+        if isinstance(ofr, (Operation, OpView, Value)):
+            val = _get_op_result_or_value(ofr)
+            dynamic_vals.append(val)
+            static_vals.append(ShapedType.get_dynamic_size())
+        else:
+            static_vals.append(ofr)
+    return dynamic_vals, static_vals
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForallOp(ForallOp):
+    """Specialization for the SCF forall op class."""
+
+    def __init__(
+        self,
+        lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
+        upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
+        steps: Sequence[Union[Value, int]],
+        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an SCF `forall` operation.
+
+        - `lower_bounds` are the values to use as lower bounds of the loop.
+        - `upper_bounds` are the values to use as upper bounds of the loop.
+        - `steps` are the values to use as loop steps.
+        - `shared_outs` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        """
+        assert (
+            len(lower_bounds) == len(upper_bounds) == len(steps)
+        ), "Mismatch in length of lower bounds, upper bounds, and steps"
+        if shared_outs is None:
+            shared_outs = []
+        shared_outs = _get_op_results_or_values(shared_outs)
+
+        dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
+        dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
+        dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
+
+        results = [arg.type for arg in shared_outs]
+        super().__init__(
+            results,
+            dynamic_lbs,
+            dynamic_ubs,
+            dynamic_steps,
+            static_lbs,
+            static_ubs,
+            static_steps,
+            shared_outs,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+        rank = len(static_lbs)
+        iv_types = [IndexType.get()] * rank
+        self.regions[0].blocks.append(*iv_types, *results)
+
+    @property
+    def body(self) -> Block:
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def rank(self) -> int:
+        """Returns the number of induction variables the loop has."""
+        return len(self.staticLowerBound)
+
+    @property
+    def induction_variables(self) -> BlockArgumentList:
+        """Returns the induction variables usable within the loop."""
+        return self.body.arguments[: self.rank]
+
+    @property
+    def inner_iter_args(self) -> BlockArgumentList:
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[self.rank :]
+
+    def terminator(self) -> InParallelOp:
+        """
+        Returns the loop terminator if it exists.
+        Otherwise, creates a new one.
+        """
+        ops = self.body.operations
+        with InsertionPoint(self.body):
+            if not ops:
+                return InParallelOp()
+            last = ops[len(ops) - 1]
+            return last if isinstance(last, InParallelOp) else InParallelOp()
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InParallelOp(InParallelOp):
+    """Specialization of the SCF forall.in_parallel op class."""
+
+    def __init__(self, loc=None, ip=None):
+        super().__init__(loc=loc, ip=ip)
+        self.region.blocks.append()
+
+    @property
+    def block(self) -> Block:
+        return self.region.blocks[0]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class IfOp(IfOp):
     """Specialization for the SCF if op class."""
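A note on `_dispatch_index_op_fold_results` above: like its C++ counterpart `mlir::dispatchIndexOpFoldResults`, it splits a mixed list of Python ints and SSA values into a list of dynamic operands plus a parallel list of static values, with a sentinel marking the dynamic positions. A hedged sketch of the behavior; the index-typed Value `n` is an assumed stand-in for any SSA value in scope:

# Given some index-typed Value `n` (assumption for illustration):
dynamic_vals, static_vals = _dispatch_index_op_fold_results([0, n, 4])
# dynamic_vals == [n]
# static_vals  == [0, ShapedType.get_dynamic_size(), 4]
# The sentinel tells the op to read that position from the dynamic operands.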
mlir/test/python/dialects/scf.py

@@ -18,6 +18,26 @@ def constructAndPrintInModule(f):
     return f


+# CHECK-LABEL: TEST: testSimpleForall
+# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
+# CHECK: arith.addi %[[IV0]], %[[IV1]]
+# CHECK: scf.forall.in_parallel
+@constructAndPrintInModule
+def testSimpleForall():
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get([4, 8], f32)
+
+    @func.FuncOp.from_py_func(tensor_type)
+    def forall_loop(tensor):
+        loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
+        with InsertionPoint(loop.body):
+            i, j = loop.induction_variables
+            arith.addi(i, j)
+            loop.terminator()
+        # The verifier will check that the regions have been created properly.
+        assert loop.verify()
+
+
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():
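One property of the `terminator()` helper exercised above is worth calling out: it creates an `scf.forall.in_parallel` op only when the body does not already end in one, so repeated calls are safe. A small sketch, reusing `loop` from the body of `forall_loop` in `testSimpleForall`:

term_a = loop.terminator()  # creates the scf.forall.in_parallel terminator
term_b = loop.terminator()  # returns the existing terminator instead of adding another
assert term_a.operation == term_b.operation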