diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index df4af026d0c5..b0c705a87a35 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -43,6 +43,9 @@ LogicalResult verifyContractionInterface(Operation *op); /// Verify that `op` conforms to the ConvolutionOpInterface. LogicalResult verifyConvolutionInterface(Operation *op); +/// Verify that `op` conforms to the FillOpInterface. +LogicalResult verifyFillInterface(Operation *op); + /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index dbf65aec9788..4ac6bb78653c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -132,6 +132,50 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> { ]; } +def LinalgFillOpInterface : OpInterface<"FillOpInterface"> { + let description = [{ + A fill operation is defined in general terms: + 1. Has a scalar `value` operand. + 2. Has one `output` operand. + }]; + let cppNamespace = "::mlir::linalg"; + let verify = [{ return detail::verifyFillInterface($_op); }]; + let methods = [ + InterfaceMethod< + /*desc=*/"Return the fill value.", + /*retTy=*/"Value", + /*methodName=*/"value", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperation()->getOperand(0); + }] + >, + InterfaceMethod< + /*desc=*/"Return the output operand.", + /*retTy=*/"Value", + /*methodName=*/"output", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperation()->getOperand(1); + }] + >, + InterfaceMethod< + /*desc=*/"Return the result.", + /*retTy=*/"Value", + /*methodName=*/"result", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + if ($_op.getOperation()->getResults().empty()) + return nullptr; + return $_op.getOperation()->getResults().front(); + }] + >, + ]; +} + // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let cppNamespace = "::mlir::linalg"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index e29600460367..7511e268ae85 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2875,6 +2875,8 @@ metadata: !LinalgOpMetadata Works for arbitrary ranked output tensors since the operation performs scalar accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. + implements: + - LinalgFillOpInterface structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 84e26b150fa3..4c796723c25a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -408,6 +408,44 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { } return success(); } + +//===----------------------------------------------------------------------===// +// FillOpInterface implementation +//===----------------------------------------------------------------------===// + +enum class MatchFillResult { + Success = 0, + NotLinalgOp, + WrongNumOperands, + NotScalarInput +}; + +static MatchFillResult isFillInterfaceImpl(Operation *op) { + auto linalgOp = dyn_cast(op); + if (!linalgOp) + return MatchFillResult::NotLinalgOp; + if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1) + return MatchFillResult::WrongNumOperands; + + OpOperand *value = linalgOp.getInputOperand(0); + if (!linalgOp.isScalar(value)) + return MatchFillResult::NotScalarInput; + + return MatchFillResult::Success; +} + +LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { + auto res = isFillInterfaceImpl(op); + if (res == MatchFillResult::NotLinalgOp) + return op->emitError("expected a LinalgOp"); + if (res == MatchFillResult::WrongNumOperands) + return op->emitError("expected op with 1 input and 1 output"); + if (res == MatchFillResult::NotScalarInput) + return op->emitError("expected op with scalar input"); + + return success(); +} + //===----------------------------------------------------------------------===// // StructuredOpInterface implementation //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 7de0a76e87b7..1de5449e27e3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -686,6 +686,7 @@ class OpInterfaceDef: ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") +FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") class OpMetadataDef(YAMLObject): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 0ef40613a7ba..7798d7f9498e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -671,6 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. """ + implements(FillOpInterface) O[None] = TypeFn.cast_signed(U, value) diff --git a/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir new file mode 100644 index 000000000000..17a5f119cfd5 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fill-interface-invalid.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +func @test_fill_op_not_linalg_op(%arg0 : f32, %arg1 : tensor) + -> tensor { + // expected-error @+1 {{expected a LinalgOp}} + %0 = "test.fill_op_not_linalg_op"(%arg0, %arg1) + : (f32, tensor) -> tensor + return %0 : tensor +} + +// ----- + +#map0 = affine_map<(d0) -> ()> +#map1 = affine_map<(d0) -> (d0)> +func @test_fill_op_wrong_num_operands(%arg0 : f32, %arg1 : tensor) + -> tensor { + // expected-error @+1 {{expected op with 1 input and 1 output}} + %0 = test.linalg_fill_op { + indexing_maps = [#map0, #map0, #map1], + iterator_types = ["parallel"]} + ins(%arg0, %arg0 : f32, f32) outs(%arg1 : tensor) { + ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): + linalg.yield %arg2 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + +#map1 = affine_map<(d0) -> (d0)> +func @test_fill_op_non_scalar_input(%arg0 : tensor, + %arg1 : tensor) -> tensor { + // expected-error @+1 {{expected op with scalar input}} + %0 = test.linalg_fill_op { + indexing_maps = [#map1, #map1], + iterator_types = ["parallel"]} + ins(%arg0 : tensor) outs(%arg1 : tensor) { + ^bb0(%arg2 : f32, %arg3 : f32): + linalg.yield %arg2 : f32 + } -> tensor + return %0 : tensor +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 9ab7d9e62e67..68139eb55533 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2640,6 +2640,64 @@ def TestLinalgConvOp : }]; } +//===----------------------------------------------------------------------===// +// Test LinalgFillOpInterface. +//===----------------------------------------------------------------------===// + +def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [ + LinalgFillOpInterface]> { + let arguments = (ins + AnyType:$value, AnyType:$output); + let results = (outs AnyRankedTensor:$result); +} + +def TestLinalgFillOp : + TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock, + LinalgStructuredInterface, LinalgFillOpInterface]> { + + let arguments = (ins Variadic:$inputs, + Variadic:$outputs); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + + let assemblyFormat = [{ + attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)? + `outs` `(` $outputs `:` type($outputs) `)` + $region (`->` type($results)^)? + }]; + + let extraClassDeclaration = [{ + bool hasIndexSemantics() { return false; } + + static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, + mlir::ArrayRef attrs) { + b.create(block.getArguments().back()); + } + + static std::function)> + getRegionBuilder() { + return ®ionBuilder; + } + + mlir::ArrayAttr iterator_types() { + return getOperation()->getAttrOfType("iterator_types"); + } + + mlir::ArrayAttr indexing_maps() { + return getOperation()->getAttrOfType("indexing_maps"); + } + + std::string getLibraryCallName() { + return ""; + } + + // To conform with interface requirement on operand naming. + mlir::ValueRange inputs() { return getInputs(); } + mlir::ValueRange outputs() { return getOutputs(); } + }]; +} + //===----------------------------------------------------------------------===// // Test Ops with Default-Valued String Attributes //===----------------------------------------------------------------------===//