[mlir][linalg][python] Add Python Bindings for Inferring Contraction Dimensions from Affine Maps (#167587)

This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to
Python, allowing users to infer contraction dimensions (batch/m/n/k)
directly from a list of affine maps without needing an operation.

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
Bangtian Liu
2025-11-12 13:35:04 -05:00
committed by GitHub
parent a22834a4d2
commit a5a78d0bb4
4 changed files with 102 additions and 3 deletions

View File

@@ -10,6 +10,7 @@
#ifndef MLIR_C_DIALECT_LINALG_H
#define MLIR_C_DIALECT_LINALG_H
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
size_t numMaps);
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
typedef struct MlirLinalgConvolutionDimensions {

View File

@@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"op.",
nb::arg("op"));
m.def(
"infer_contraction_dimensions_from_maps",
[](std::vector<MlirAffineMap> indexingMaps)
-> std::optional<MlirLinalgContractionDimensions> {
if (indexingMaps.empty())
return std::nullopt;
MlirLinalgContractionDimensions dims =
mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
indexingMaps.size());
// Detect "empty" result from invalid input or failed inference.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
return dims;
},
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
"maps.",
nb::arg("indexing_maps"));
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));

View File

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Linalg.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
const linalg::ContractionDimensions &contractionDims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
return wrap(
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};
result.batch = toAttr(contractionDims.batch);
@@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
size_t numMaps) {
MlirLinalgContractionDimensions result{};
if (!indexingMaps || numMaps == 0)
return result;
SmallVector<AffineMap, 3> maps;
maps.reserve(numMaps);
for (size_t i = 0; i < numMaps; ++i) {
maps.push_back(unwrap(indexingMaps[i]));
}
FailureOr<linalg::ContractionDimensions> maybeDims =
linalg::inferContractionDims(maps);
if (failed(maybeDims))
return result;
MLIRContext *ctx = maps[0].getContext();
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};
result.batch = toAttr(maybeDims->batch);
result.m = toAttr(maybeDims->m);
result.n = toAttr(maybeDims->n);
result.k = toAttr(maybeDims->k);
return result;
}
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)

View File

@@ -208,3 +208,43 @@ def test_get_indexing_maps_attr():
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map
@run
def test_infer_contraction_dimensions_from_maps():
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
# === Test valid contraction (matmul) ===
dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
assert dims is not None
# Expect m=[0], n=[1], k=[2] as per standard matmul.
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"
# === Test invalid input (wrong number of maps) ===
invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
assert invalid_dims is None
# === Test element-wise operation ===
dim_i = AffineDimExpr.get(0)
dim_j = AffineDimExpr.get(1)
elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
[elementwise_map, elementwise_map, elementwise_map]
)
assert elementwise_dims is not None
assert len(elementwise_dims.m) == 0
assert len(elementwise_dims.n) == 0
assert len(elementwise_dims.k) == 0
assert list(elementwise_dims.batch) == [0, 1]