[mlir][transform][gpu][python] Add MapForallToBlocks mix-in.
This patch adds a mix-in class for MapForallToBlocks with overloaded constructors. This makes it optional to provide the return type of the op, which defaults to `AnyOpType`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155717
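For illustration, the compact call form the mix-in enables, mirroring the tests added below (a minimal sketch; assumes an MLIR build with the Python bindings on PYTHONPATH):

    from mlir.ir import Context, InsertionPoint, Location, Module
    from mlir.dialects import transform
    from mlir.dialects.transform import gpu

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
            )
            with InsertionPoint(sequence.body):
                # Compact form: the result type defaults to !transform.any_op.
                gpu.MapForallToBlocks(sequence.bodyTarget)
                transform.YieldOp()
        print(module)

    # Explicit form: pass the result type as the first argument, e.g.
    #   gpu.MapForallToBlocks(
    #       transform.OperationType.get("test.dummy"), sequence.bodyTarget)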
@@ -148,6 +148,7 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/GPUTransformOps.td
   SOURCES
+    dialects/_gpu_transform_ops_ext.py
     dialects/transform/gpu.py
   DIALECT_NAME transform
   EXTENSION_NAME gpu_transform)
69 mlir/python/mlir/dialects/_gpu_transform_ops_ext.py Normal file
@@ -0,0 +1,69 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ..dialects import transform
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, overload


class MapForallToBlocks:
    """Specialization for MapForallToBlocks class."""

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target

        if grid_dims is not None and not isinstance(grid_dims, ArrayAttr):
            grid_dims = DenseI64ArrayAttr.get(grid_dims)

        super().__init__(
            result_type,
            target,
            grid_dims=grid_dims,
            generate_gpu_launch=generate_gpu_launch,
            loc=loc,
            ip=ip,
        )
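A note on the dispatch: the @overload stubs above only inform static type checkers and are erased at runtime, so the single real __init__ recovers the intended signature by testing the type of the first positional argument. The same pattern in isolation (an illustrative sketch with hypothetical names, not part of the patch):

    from typing import Optional, Union, overload

    class Dispatcher:
        @overload
        def __init__(self, result_type: str, target: int):
            ...

        @overload
        def __init__(self, target: int):
            ...

        def __init__(
            self,
            result_type_or_target: Union[str, int],
            target_or_none: Optional[int] = None,
        ):
            # The stubs are gone at runtime; isinstance on the first
            # argument decides which overload the caller meant.
            if isinstance(result_type_or_target, str):
                self.result_type = result_type_or_target
                self.target = target_or_none
            else:
                self.result_type = "default"  # plays the role of AnyOpType.get()
                self.target = result_type_or_target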
59 mlir/test/python/dialects/transform_gpu_ext.py Normal file
@@ -0,0 +1,59 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import gpu


def run(f):
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        print(module)
    return f


@run
def testMapForallToBlocksCompact():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        gpu.MapForallToBlocks(sequence.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapForallToBlocksCompact
    # CHECK: = transform.gpu.map_forall_to_blocks
    # CHECK-NOT: grid_dims
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
    # CHECK-NOT: grid_dims


@run
def testMapForallToBlocksTyped():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        gpu.MapForallToBlocks(
            transform.OperationType.get("test.dummy"), sequence.bodyTarget
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapForallToBlocksTyped
    # CHECK: = transform.gpu.map_forall_to_blocks
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">


@run
def testMapForallToBlocksGridDims():
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        gpu.MapForallToBlocks(sequence.bodyTarget, grid_dims=[4, 2])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapForallToBlocksGridDims
    # CHECK: = transform.gpu.map_forall_to_blocks
    # CHECK-SAME: grid_dims = [4, 2]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
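The tests cover the default and explicit result types and grid_dims; generate_gpu_launch, which the mix-in declares as Optional[bool] and forwards unchanged, is not exercised. A sketch of a fourth test by analogy (testMapForallToBlocksGpuLaunch is a hypothetical name; untested, and it assumes the generated op builder accepts a plain bool for the unit attribute):

    @run
    def testMapForallToBlocksGpuLaunch():
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
        )
        with InsertionPoint(sequence.body):
            # Asks the transform to wrap the mapped code in a gpu.launch op.
            gpu.MapForallToBlocks(sequence.bodyTarget, generate_gpu_launch=True)
            transform.YieldOp()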