diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 339e63d667c5..003b0cde3965 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -10,6 +10,7 @@ #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions { MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps); + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); typedef struct MlirLinalgConvolutionDimensions { diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 015502371c65..0b079b404d42 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "op.", nb::arg("op")); + m.def( + "infer_contraction_dimensions_from_maps", + [](std::vector indexingMaps) + -> std::optional { + if (indexingMaps.empty()) + return std::nullopt; + + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(), + indexingMaps.size()); + + // Detect "empty" result from invalid input or failed inference. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; + }, + "Infers contraction dimensions (batch/m/n/k) from a list of affine " + "maps.", + nb::arg("indexing_maps")); + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, "Checks if the given operation is a Linalg convolution operation.", nb::arg("op")); diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 5c2a65d2c4c8..75c811aed6cc 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Linalg.h" +#include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); - auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { - return wrap( - DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + auto toAttr = [ctx](ArrayRef vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); }; result.batch = toAttr(contractionDims.batch); @@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps) { + MlirLinalgContractionDimensions result{}; + if (!indexingMaps || numMaps == 0) + return result; + + SmallVector maps; + maps.reserve(numMaps); + for (size_t i = 0; i < numMaps; ++i) { + maps.push_back(unwrap(indexingMaps[i])); + } + + FailureOr maybeDims = + linalg::inferContractionDims(maps); + if (failed(maybeDims)) + return result; + + MLIRContext *ctx = maps[0].getContext(); + + auto toAttr = [ctx](ArrayRef vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + result.batch = toAttr(maybeDims->batch); + result.m = toAttr(maybeDims->m); + result.n = toAttr(maybeDims->n); + result.k = toAttr(maybeDims->k); + + return result; +} + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { auto linalgOp = llvm::dyn_cast(unwrap(op)); if (!linalgOp) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 5f7cb6a6c83c..8ab53b4e2874 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -208,3 +208,43 @@ def test_get_indexing_maps_attr(): assert maps[0] == a_map assert maps[1] == b_map assert maps[2] == c_map + + +@run +def test_infer_contraction_dimensions_from_maps(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + # === Test valid contraction (matmul) === + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map]) + assert dims is not None + + # Expect m=[0], n=[1], k=[2] as per standard matmul. + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}" + + # === Test invalid input (wrong number of maps) === + invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map]) + assert invalid_dims is None + + # === Test element-wise operation === + dim_i = AffineDimExpr.get(0) + dim_j = AffineDimExpr.get(1) + elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j]) + elementwise_dims = linalg.infer_contraction_dimensions_from_maps( + [elementwise_map, elementwise_map, elementwise_map] + ) + assert elementwise_dims is not None + assert len(elementwise_dims.m) == 0 + assert len(elementwise_dims.n) == 0 + assert len(elementwise_dims.k) == 0 + assert list(elementwise_dims.batch) == [0, 1]