[mlir][transform][gpu][python] Add MapForallToBlocks mix-in.
This patch adds a mix-in class for MapForallToBlocks with overloaded constructors. This makes providing the return type of the op optional; it defaults to `AnyOpType`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D155717
@@ -148,6 +148,7 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/GPUTransformOps.td
   SOURCES
+    dialects/_gpu_transform_ops_ext.py
     dialects/transform/gpu.py
   DIALECT_NAME transform
   EXTENSION_NAME gpu_transform)
mlir/python/mlir/dialects/_gpu_transform_ops_ext.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ..dialects import transform
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union, overload


class MapForallToBlocks:
    """Specialization for MapForallToBlocks class."""

    @overload
    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        ...

    def __init__(
        self,
        result_type_or_target: Union[Operation, OpView, Type, Value],
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        grid_dims: Optional[Sequence[int]] = None,
        generate_gpu_launch: Optional[bool] = None,
        loc=None,
        ip=None
    ):
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target

        if grid_dims is not None and not isinstance(grid_dims, ArrayAttr):
            grid_dims = DenseI64ArrayAttr.get(grid_dims)

        super().__init__(
            result_type,
            target,
            grid_dims=grid_dims,
            generate_gpu_launch=generate_gpu_launch,
            loc=loc,
            ip=ip,
        )
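
For context, a minimal usage sketch of the two constructor forms provided by this mix-in. This is not part of the commit: the helper name `emit_map_forall_to_blocks` is invented for illustration, and it assumes `target` is a transform handle (an `Operation`, `OpView`, or `Value`) produced inside an enclosing transform sequence, with an active MLIR context and insertion point.

from mlir.dialects import transform
from mlir.dialects.transform import gpu


def emit_map_forall_to_blocks(target):
    # Compact form: the result handle type defaults to transform.AnyOpType.
    compact = gpu.MapForallToBlocks(target, grid_dims=[4, 2])

    # Explicit form: pass the desired result type as the first positional
    # argument, before the target handle.
    typed = gpu.MapForallToBlocks(
        transform.AnyOpType.get(), target, generate_gpu_launch=True
    )
    return compact, typed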