From a3655de2c81fc959590c109d81a010fc8e09c48e Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 11 Feb 2022 08:20:37 +0000 Subject: [PATCH] [mlir][OpDSL] Add support for basic rank polymorphism. Previously, OpDSL did not support rank polymorphism, which required a separate implementation of linalg.fill. This revision extends OpDSL to support rank polymorphism for a limited class of operations that access only scalars and tensors of rank zero. At operation instantiation time, it scales these scalar computations to multi-dimensional pointwise computations by replacing the empty indexing maps with identity index maps. The revision does not change the DSL itself, instead it adapts the Python emitter and the YAML generator to generate different indexing maps and and iterators depending on the rank of the first output. Additionally, the revision introduces a `linalg.fill_tensor` operation that in a future revision shall replace the current handwritten `linalg.fill` operation. `linalg.fill_tensor` is thus only temporarily available and will be renamed to `linalg.fill`. Reviewed By: nicolasvasilache, stellaraccident Differential Revision: https://reviews.llvm.org/D119003 --- mlir/docs/Dialects/Linalg/OpDSL.md | 30 ++- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 36 +++ .../dialects/linalg/opdsl/lang/emitter.py | 32 ++- .../linalg/opdsl/ops/core_named_ops.py | 11 + .../generalize-named-polymorphic-ops.mlir | 29 +++ .../test-linalg-ods-yaml-gen.yaml | 52 ++++ .../python/dialects/linalg/opdsl/emit_fill.py | 46 ++++ mlir/test/python/dialects/linalg/ops.py | 2 +- .../integration/dialects/linalg/opsrun.py | 125 +++++++++- .../mlir-linalg-ods-yaml-gen.cpp | 230 ++++++++++-------- 10 files changed, 473 insertions(+), 120 deletions(-) create mode 100644 mlir/test/python/dialects/linalg/opdsl/emit_fill.py diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md index 79f22a247bb2..deec3eae0fd2 100644 --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -102,7 +102,7 @@ bound to a `TensorDef` as demonstrated by the matmul example. All parameters appear in the parameter list of the operation: ```python -fill(val, in_tensor, outs=[out_tensor]) +copy_and_scale(val, in_tensor, outs=[out_tensor]) ``` ## Attributes @@ -251,3 +251,31 @@ The following examples illustrate the lowering of signed and unsigned functions: Not all functions are applicable for all numeric types, and on mismatch, op verification will fail. + +## Pointwise Computations + +Pointwise computations are expressible in a rank polymorphic form that supports +arbitrary ranked operands - all of them need to have the same rank - with a +single operation definition. + +An example for a rank polymorphic operation is `fill`: + +```python +@linalg_structured_op +def fill(value=ScalarDef(T1), + O=TensorDef(U, output=True)): + O[None] = TypeFn.cast(U, value) +``` + +The operation sets the elements of the output tensor `O` to `value`. All +operands are either scalars or rank zero tensors that are accessed using the +index `None`. The operation thus performs a scalar computation that trivially +extends to a multi-dimensional pointwise computation. As a result, we may use +`fill` with arbitrary ranked output tensors: + +```python +tensor_2d = linalg.InitTensorOp([4, 8], f32) +tensor_3d = linalg.InitTensorOp([4, 8, 16], f32) +fill(value, outs=[tensor_2d]) +fill(value, outs=[tensor_3d]) +``` diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index dc5e5862e83c..69a4cc407b9d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2522,6 +2522,42 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: I --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: fill_tensor + cpp_class_name: FillTensorOp + doc: |- + Fills the output tensor with the given value. + + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: value + usage: InputOperand + type_var: T1 + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + type_fn: + fn_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: value +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d cpp_class_name: FillRng2DOp diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 22568c8b6748..643bcaa5c2f0 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -14,6 +14,7 @@ from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, g from .scalar_expr import * from .config import * +from .comprehension import * import numpy as np __all__ = [ @@ -132,6 +133,25 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) + # An operation that accesses only scalars and scalar/rank zero tensors is + # rank polymorhpic. We implement rank polymorphism by generating different + # indexing maps and iterators that match the rank of the first output tensor. + # An operation is rank polymorphic if the iteration domain has rank zero. + if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.Scalar: + indexing_maps.append(scalar_map) + if (arg_def.operand_def.kind == OperandKind.InputTensor or + arg_def.operand_def.kind == OperandKind.OutputTensor): + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps]) + generic_op = linalg.GenericOp( result_tensors=result_types, inputs=ins, @@ -172,19 +192,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") + # Set the index attributes used to compute the indexing maps. named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - linalg.fill_builtin_region(named_op.operation) - # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps - # attribute that the non-yaml path does not. The non-yaml path hardcodes the - # indexing_maps in C++ directly. - named_op.operation.attributes[ - "linalg.memoized_indexing_maps"] = indexing_maps_attr - # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. - - # Additionally set all named attributes. for name, value in index_attributes.items(): named_op.operation.attributes[name] = value + linalg.fill_builtin_region(named_op.operation) + if len(result_types) == 1: return named_op.result else: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index d3651bd766fe..80a8fb6ccf09 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -627,6 +627,17 @@ def pooling_ndhwc_min( D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): + """Fills the output tensor with the given value. + + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. + """ + O[None] = TypeFn.cast(U, value) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir index b01191184b05..e5a7e74fc582 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -207,6 +207,35 @@ func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: ten // ----- +func @generalize_fill_0d(%value: f64, %O: tensor) -> tensor { + %0 = linalg.fill_tensor ins(%value: f64) outs(%O : tensor) -> tensor + return %0: tensor +} + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()> + +// CHECK-LABEL: @generalize_fill_0d +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]] +// CHECK-SAME: iterator_types = [] + +// ----- + +func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) { + linalg.fill_tensor ins(%value: f64) outs(%O : memref<16x32xf32>) + return +} + +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @generalize_fill +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] + +// ----- + func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> { %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32> return %0: tensor<16x32xf32> diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index 3634f4f83dd4..ee36510aaf00 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -175,3 +175,55 @@ structured_op: !LinalgStructuredOpConfig # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && # IMPL: yields.push_back(block.getArgument(0)); + +# @linalg_structured_op +# def test3(value=ScalarDef(T1), +# O=TensorDef(U, output=True)): +# """Title. + +# Detailed description. +# """ +# O[None] = TypeFn.cast(U, value) + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test3 + cpp_class_name: Test3Op + doc: |- + Title. + + Detailed description. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: value + usage: InputOperand + type_var: T1 + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + type_fn: + fn_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: value + +# IMPL: Test3Op::iterator_types() { +# IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0)); + +# IMPL: Test3Op::indexing_maps() { +# IMPL-NEXT: MLIRContext *context = getContext(); +# IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context); +# IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py new file mode 100644 index 000000000000..75524691a487 --- /dev/null +++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py @@ -0,0 +1,46 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import builtin +from mlir.dialects import linalg +from mlir.dialects import std + +from mlir.dialects.linalg.opdsl.lang import * + +T1 = TV.T1 +T2 = TV.T2 + + +@linalg_structured_op +def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)): + O[None] = TypeFn.cast(U, value) + + +with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + # Fill indexing maps. + # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()> + # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> + # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> + + # CHECK-LABEL: @test_fill_0d + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]] + # CHECK-SAME: iterator_types = [] + @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([], f32)) + def test_fill_0d(value, init_result): + return fill_poly(value, outs=[init_result]) + + # CHECK-LABEL: @test_fill_2d + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel"] + @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32)) + def test_fill_2d(value, init_result): + return fill_poly(value, outs=[init_result]) + + +print(module) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 4f9f138683b8..ba57a131f7f3 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -126,7 +126,7 @@ def testNamedStructuredOpGenericForm(): # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : + # CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32> # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return linalg.matmul(lhs, rhs, outs=[init_result.result]) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index b75de1208550..5be00f8df333 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -42,13 +42,42 @@ func @main() -> f32 attributes {llvm.emit_c_interface} { """ fill_boiler = """ +func @main() -> i32 attributes {llvm.emit_c_interface} { + %O0 = memref.alloc() : memref + %O1 = memref.alloc() : memref<16xi32> + %O2 = memref.alloc() : memref<4x16xi32> + + %val0 = arith.constant 1.0 : f32 + %val1 = arith.constant 2.0 : f32 + %val2 = arith.constant 3.0 : f32 + + call @fill_0d_on_buffers(%val0, %O0) : (f32, memref) -> () + call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> () + call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> () + + %c0 = arith.constant 0 : index + %res0 = memref.load %O0[] : memref + %c8 = arith.constant 8 : index + %res1 = memref.load %O1[%c8] : memref<16xi32> + %c2 = arith.constant 2 : index + %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32> + + %0 = arith.addi %res0, %res1 : i32 + %1 = arith.addi %0, %res2 : i32 + + // TODO: FFI-based solution to allow testing and printing with python code. + return %1 : i32 +} +""" + +fill_rng_boiler = """ func @main() -> i32 attributes {llvm.emit_c_interface} { %O = memref.alloc() : memref<4x16xi32> %min = arith.constant -1000.0 : f64 %max = arith.constant 1000.0 : f64 %seed = arith.constant 42 : i32 - call @fill_on_buffers(%min, %max, %seed, %O) : + call @fill_rng_on_buffers(%min, %max, %seed, %O) : (f64, f64, i32, memref<4x16xi32>) -> () %c0 = arith.constant 0 : index @@ -123,9 +152,9 @@ def transform(module, boilerplate): # TODO: Allow cloning functions from one module to another. # Atm we have to resort to string concatenation. - mod = Module.parse( - str(module.operation.regions[0].blocks[0].operations[0].operation) + - boilerplate) + ops = module.operation.regions[0].blocks[0].operations + mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate) + pm = PassManager.parse( "builtin.func(convert-linalg-to-loops, lower-affine, " + "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," + @@ -194,13 +223,21 @@ test_matmul_generic() def test_fill_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() - f64 = F64Type.get() + f32 = F32Type.get() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) - def fill_on_buffers(min, max, seed, out): - linalg.fill_rng_2d(min, max, seed, outs=[out]) + @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + def fill_0d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out]) + + @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + def fill_1d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out]) + + @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + def fill_2d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out]) execution_engine = ExecutionEngine(transform(module, fill_boiler)) @@ -212,13 +249,48 @@ def test_fill_builtin(): execution_engine.invoke("main", res) log("RESULT: ", res[0]) - # CHECK: RESULT: -480 + # CHECK: RESULT: 6 test_fill_builtin() def test_fill_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + def fill_0d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out], emit_generic=True) + + @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + def fill_1d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out], emit_generic=True) + + @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + def fill_2d_on_buffers(value, out): + linalg.fill_tensor(value, outs=[out], emit_generic=True) + + execution_engine = ExecutionEngine(transform(module, fill_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: 6 + + +test_fill_generic() + + +def test_fill_rng_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() f64 = F64Type.get() @@ -226,10 +298,10 @@ def test_fill_generic(): with InsertionPoint(module.body): @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) - def fill_on_buffers(min, max, seed, out): - linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True) + def fill_rng_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out]) - execution_engine = ExecutionEngine(transform(module, fill_boiler)) + execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) # TODO: FFI-based solution to allow testing and printing with python code. # Prepare arguments: one result i32. @@ -242,7 +314,34 @@ def test_fill_generic(): # CHECK: RESULT: -480 -test_fill_generic() +test_fill_rng_builtin() + + +def test_fill_rng_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) + def fill_rng_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True) + + execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: -480 + + +test_fill_rng_generic() def test_max_pooling_builtin(): diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 925b90848ac0..d5d8ba6e0db1 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -558,16 +558,63 @@ static const char structuredOpBuilderFormat[] = R"FMT( }]> )FMT"; -// The iterator_types() method implementation. Parameters: +// The iterator_types() method for structured ops. Parameters: // {0}: Class name // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = R"FMT( -ArrayAttr {0}::iterator_types() { +ArrayAttr {0}::iterator_types() {{ return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); } )FMT"; +// The iterator_types() method for rank polymorphic structured ops. Parameters: +// {0}: Class name +static const char rankPolyStructuredOpIteratorTypesFormat[] = + R"FMT( +ArrayAttr {0}::iterator_types() {{ + int64_t rank = getRank(getOutputOperand(0)); + return Builder(getContext()).getStrArrayAttr( + SmallVector(rank, getParallelIteratorTypeName())); +} +)FMT"; + +// The indexing_maps() method for structured ops. Parameters: +// {0}: Class name +// {1}: Comma-separated list of dimension variable names. +// {2}: Statements +static const char structuredOpIndexingMapsFormat[] = R"FMT( +ArrayAttr {0}::indexing_maps() {{ + static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; + ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); + if (cached) + return cached; + + MLIRContext *context = getContext(); + auto symbolBindings = getSymbolBindings(*this); + SmallVector maps; + {2} + cached = Builder(context).getAffineMapArrayAttr(maps); + getOperation()->setAttr(memoizeAttr, cached); + return cached; +} +)FMT"; + +// The indexing_maps() method for rank polymorphic structured ops. Parameters: +// {0}: Class name +static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT( +ArrayAttr {0}::indexing_maps() {{ + MLIRContext *context = getContext(); + AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context); + AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( + getNumParallelLoops(), context); + SmallVector indexingMaps; + for (OpOperand *opOperand : getInputAndOutputOperands()) + indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap); + return Builder(getContext()).getAffineMapArrayAttr(indexingMaps); +} +)FMT"; + // Implementations of fold and getEffects. // Parameters: // {0}: Class name @@ -681,8 +728,14 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig, return arg.usage != LinalgOperandDefUsage::attribute; }); - // Reference iterators. - { + // An operation that accesses only scalars and scalar/rank zero tensors is + // rank polymorhpic. We implement rank polymorphism by generating different + // indexing maps and iterators that match the rank of the first output tensor. + // An operation is rank polymorphic if the iteration domain has rank zero. + bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty(); + + // Generate the iterator_types() method. + if (!isRankPolymorphic) { std::string iteratorsStr; llvm::raw_string_ostream ss(iteratorsStr); llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, @@ -699,22 +752,25 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig, ss.flush(); os << llvm::formatv(structuredOpIteratorTypesFormat, className, iteratorsStr); + } else { + os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className); } - // Static indexing maps. + // Generating the indexing_maps() method. if (auto &staticMaps = opConfig.structuredOp->indexingMaps.staticIndexingMaps) { if (staticMaps->empty()) return emitError(genContext.getLoc()) << "op has no indexing maps"; - AffineMap firstMap = staticMaps->front().affineMap(); + if (!isRankPolymorphic) { + AffineMap firstMap = staticMaps->front().affineMap(); - // Symbol bindings. - { - // For each symbol, generate a declaration for it, either with an - // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from - // an attribute). - // TODO: Possibly lift into a top-level method. - static const char structuredOpSymbolBindingsFormat[] = R"FMT( + // Symbol bindings. + { + // For each symbol, generate a declaration for it, either with an + // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from + // an attribute). + // TODO: Possibly lift into a top-level method. + static const char structuredOpSymbolBindingsFormat[] = R"FMT( static SmallVector getSymbolBindings({0} self) { MLIRContext *context = self.getContext(); SmallVector exprs; @@ -723,101 +779,83 @@ static SmallVector getSymbolBindings({0} self) { } )FMT"; - unsigned symbolCount = firstMap.getNumSymbols(); - SmallVector symbolBindings; - for (unsigned i = 0; i < symbolCount; ++i) { - symbolBindings.push_back(llvm::formatv( - " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); - } + unsigned symbolCount = firstMap.getNumSymbols(); + SmallVector symbolBindings; + for (unsigned i = 0; i < symbolCount; ++i) { + symbolBindings.push_back(llvm::formatv( + " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); + } - // Access an index attribute. Parameters: - // {0}: Attribute name - // {1}: Symbol position - // {2}: Attribute index - static const char structuredOpAccessAttrFormat[] = R"FMT( + // Access an index attribute. Parameters: + // {0}: Attribute name + // {1}: Symbol position + // {2}: Attribute index + static const char structuredOpAccessAttrFormat[] = R"FMT( int64_t cst{1} = self.{0}().getValues()[{2}]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; - // Update all symbol bindings mapped to an attribute. - for (LinalgOperandDef &arg : opConfig.structuredOp->args) { - if (arg.usage != LinalgOperandDefUsage::attribute) - continue; - assert(arg.attributeMap.hasValue()); - for (auto &en : - llvm::enumerate(arg.attributeMap->affineMap().getResults())) { - if (auto symbol = en.value().dyn_cast()) { - symbolBindings[symbol.getPosition()] = - llvm::formatv(structuredOpAccessAttrFormat, arg.name, - symbol.getPosition(), en.index()); + // Update all symbol bindings mapped to an attribute. + for (LinalgOperandDef &arg : opConfig.structuredOp->args) { + if (arg.usage != LinalgOperandDefUsage::attribute) + continue; + assert(arg.attributeMap.hasValue()); + for (auto &en : + llvm::enumerate(arg.attributeMap->affineMap().getResults())) { + if (auto symbol = en.value().dyn_cast()) { + symbolBindings[symbol.getPosition()] = + llvm::formatv(structuredOpAccessAttrFormat, arg.name, + symbol.getPosition(), en.index()); + } } } + + std::string symbolBindingsStr; + llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); + llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); + symbolBindingsSs.flush(); + + os << llvm::formatv(structuredOpSymbolBindingsFormat, className, + symbolBindingsStr); } - std::string symbolBindingsStr; - llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); - llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); - symbolBindingsSs.flush(); + // Indexing maps. + { + unsigned dimCount = firstMap.getNumDims(); - os << llvm::formatv(structuredOpSymbolBindingsFormat, className, - symbolBindingsStr); - } + // Generate a comma-separated list of dim identifiers to be passed to + // bindDims, ensuring tht AffineExpr identifiers are bound in the right + // order to the proper AffineDimExpr. + // This results in vars in scope like: d0, d1, d2... + SmallVector dimIndices; + for (unsigned i = 0; i < dimCount; ++i) + dimIndices.push_back(i); + std::string dimIdentsStr; + llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); + llvm::interleaveComma(dimIndices, dimIdentsSs, + [&](unsigned i) { dimIdentsSs << "d" << i; }); + dimIdentsSs.flush(); - // Indexing maps. - { - // Parameters: - // {0}: Class name - // {1}: Comma-separated list of dimension variable names. - // {2}: Statements - static const char structuredOpIndexingMapsFormat[] = R"FMT( -ArrayAttr {0}::indexing_maps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - return cached; + // Statements to add and simplify each affine map. + SmallVector stmts; + for (auto &indexingMap : *staticMaps) { + // TODO: Assert that dim and symbol count match the first. + stmts.push_back( + llvm::formatv("maps.push_back({0});", + generateCppExpression(indexingMap, "context"))); + stmts.push_back(llvm::formatv( + "maps.back() = " + "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " + "symbolBindings, {0}, 0));", + dimCount)); + } - MLIRContext *context = getContext(); - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - {2} - cached = Builder(context).getAffineMapArrayAttr(maps); - getOperation()->setAttr(memoizeAttr, cached); - return cached; -} -)FMT"; - - unsigned dimCount = firstMap.getNumDims(); - - // Generate a comma-separated list of dim identifiers to be passed to - // bindDims, ensuring tht AffineExpr identifiers are bound in the right - // order to the proper AffineDimExpr. - // This results in vars in scope like: d0, d1, d2... - SmallVector dimIndices; - for (unsigned i = 0; i < dimCount; ++i) - dimIndices.push_back(i); - std::string dimIdentsStr; - llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); - llvm::interleaveComma(dimIndices, dimIdentsSs, - [&](unsigned i) { dimIdentsSs << "d" << i; }); - dimIdentsSs.flush(); - - // Statements to add and simplify each affine map. - SmallVector stmts; - for (auto &indexingMap : *staticMaps) { - // TODO: Assert that dim and symbol count match the first. - stmts.push_back( - llvm::formatv("maps.push_back({0});", - generateCppExpression(indexingMap, "context"))); - stmts.push_back(llvm::formatv( - "maps.back() = " - "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " - "symbolBindings, {0}, 0));", - dimCount)); + // TODO: This needs to be memoized and/or converted to non-parser based + // C++ codegen prior to real use. + os << llvm::formatv(structuredOpIndexingMapsFormat, className, + dimIdentsStr, interleaveToString(stmts, "\n ")); } - - // TODO: This needs to be memoized and/or converted to non-parser based - // C++ codegen prior to real use. - os << llvm::formatv(structuredOpIndexingMapsFormat, className, - dimIdentsStr, interleaveToString(stmts, "\n ")); + } else { + os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className); } } else { return emitError(genContext.getLoc())