[mlir][OpDSL] Add support for basic rank polymorphism.

Previously, OpDSL did not support rank polymorphism, which required a separate implementation of linalg.fill. This revision extends OpDSL to support rank polymorphism for a limited class of operations that access only scalars and tensors of rank zero. At operation instantiation time, it scales these scalar computations to multi-dimensional pointwise computations by replacing the empty indexing maps with identity index maps. The revision does not change the DSL itself; instead, it adapts the Python emitter and the YAML generator to generate different indexing maps and iterators depending on the rank of the first output.

Additionally, the revision introduces a `linalg.fill_tensor` operation that in a future revision shall replace the current handwritten `linalg.fill` operation. `linalg.fill_tensor` is thus only temporarily available and will be renamed to `linalg.fill`.

Reviewed By: nicolasvasilache, stellaraccident

Differential Revision: https://reviews.llvm.org/D119003
This commit is contained in:
gysit
2022-02-11 08:20:37 +00:00
parent b21f497a78
commit a3655de2c8
10 changed files with 473 additions and 120 deletions

View File

@@ -14,6 +14,7 @@ from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, g
from .scalar_expr import *
from .config import *
from .comprehension import *
import numpy as np
__all__ = [
@@ -132,6 +133,25 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# An operation that accesses only scalars and scalar/rank zero tensors is
# rank polymorphic. We implement rank polymorphism by generating different
# indexing maps and iterators that match the rank of the first output tensor.
# An operation is rank polymorphic if the iteration domain has rank zero.
if not iterator_types_attr:
rank = ShapedType(outs[0].type).rank
iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank)
scalar_map = AffineMap.get(rank, 0, [])
tensor_map = AffineMap.get_identity(rank)
indexing_maps = []
for arg_def in all_arg_defs:
if arg_def.operand_def.kind == OperandKind.Scalar:
indexing_maps.append(scalar_map)
if (arg_def.operand_def.kind == OperandKind.InputTensor or
arg_def.operand_def.kind == OperandKind.OutputTensor):
indexing_maps.append(tensor_map)
indexing_maps_attr = ArrayAttr.get(
[AffineMapAttr.get(am) for am in indexing_maps])
generic_op = linalg.GenericOp(
result_tensors=result_types,
inputs=ins,
@@ -172,19 +192,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
raise NotImplementedError(
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
# Set the index attributes used to compute the indexing maps.
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
linalg.fill_builtin_region(named_op.operation)
# Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
# attribute that the non-yaml path does not. The non-yaml path hardcodes the
# indexing_maps in C++ directly.
named_op.operation.attributes[
"linalg.memoized_indexing_maps"] = indexing_maps_attr
# iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
# Additionally set all named attributes.
for name, value in index_attributes.items():
named_op.operation.attributes[name] = value
linalg.fill_builtin_region(named_op.operation)
if len(result_types) == 1:
return named_op.result
else:

View File

@@ -627,6 +627,17 @@ def pooling_ndhwc_min(
D.ow * S.SW + D.kw * S.DW, D.c]))
@linalg_structured_op
def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
"""Fills the output tensor with the given value.
Works for arbitrary ranked output tensors since the operation performs scalar
accesses only and is thus rank polymorphic. Numeric casting is performed on
the value operand, promoting it to the same data type as the output.
"""
# The output is accessed as `O[None]`, i.e. without any index dimensions
# (presumably the "scalar access" the docstring refers to); `value` is cast
# to `U`, the type variable that `O` is declared with.
O[None] = TypeFn.cast(U, value)
@linalg_structured_op
def fill_rng_2d(
min=ScalarDef(F64),