From d8dc1c22bf926cb8c87d7ff72bae6aafe076bbc2 Mon Sep 17 00:00:00 2001
From: Renato Golin
Date: Fri, 7 Jul 2023 12:04:30 +0100
Subject: [PATCH] [MLIR][Linalg] Add max named op to linalg

I've been trying to come up with a simple and clean implementation for
ReLU. TOSA uses `clamp`, which is probably the end goal, but that would
need table-gen to be efficient (attributes, lowering only `min` or `max`).

For now, `max` is a reasonable named op on its own, beyond the ReLU use
case, so we can start using it for tiling and fusion; once that works, we
can create a more complete `clamp` op that doesn't need a whole tensor
filled with zeroes or ones to implement the different activation functions.

As with other named ops, we start by "requiring" explicit type casts,
broadcasts and zero-filled constant tensors, leaving them to a more
complex pattern-matcher, and can slowly simplify with attributes or
structured matchers (e.g. PDL) in the future.

Differential Revision: https://reviews.llvm.org/D154703
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 49 +++++++++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        | 19 +++++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 ++++++++++
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 16 ++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 34 +++++++++++++
 5 files changed, 143 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 11fa49ef3468..d021376ff4cd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -613,6 +613,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: max
+  cpp_class_name: MaxOp
+  doc: |-
+    Takes the max (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done before calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: max_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 9cc252eb7102..e4512cd1e057 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -219,6 +219,25 @@ def div_unsigned(
     O[None] = lhs[None] / rhs[None]
 
 
+@linalg_structured_op
+def max(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the max (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done before calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7e96ad2b0b24..af616a0a7bd8 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -537,3 +537,28 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 // CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
 // CHECK-NEXT: %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
 // CHECK-NEXT: linalg.yield %[[negf]] : f32
+
+// -----
+
+func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_max
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[max:.+]] = arith.maxf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[max]] : f32
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index b5cd9b659dd1..c351e139a97e 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -173,3 +173,19 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
   linalg.negf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ca4350e30566..8f00d5465532 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1540,3 +1540,37 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
   %1 = linalg.negf ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @max_dynamic
+func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_static
+func.func @max_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_tensor
+func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
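
Usage note (not part of the patch): a minimal sketch of the ReLU pattern the
commit message describes, built from the new `linalg.max` plus a zero-filled
tensor. The function name `@relu_sketch` and the 4x8xf32 shape are illustrative
assumptions only; a later pattern-matcher could recognize this fill + max pair
as an activation instead of materializing the zero tensor.

func.func @relu_sketch(%input: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Materialize the zero-filled tensor mentioned in the commit message.
  %zero = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<4x8xf32>
  %zeros = linalg.fill ins(%zero : f32) outs(%empty : tensor<4x8xf32>) -> tensor<4x8xf32>
  // Elementwise signed max against zero yields ReLU.
  %init = tensor.empty() : tensor<4x8xf32>
  %relu = linalg.max ins(%input, %zeros : tensor<4x8xf32>, tensor<4x8xf32>)
                     outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %relu : tensor<4x8xf32>
}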