mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[mlir][linalg][python] Add Python Bindings for Inferring Contraction Dimensions from Affine Maps (#167587)
This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#ifndef MLIR_C_DIALECT_LINALG_H
|
||||
#define MLIR_C_DIALECT_LINALG_H
|
||||
|
||||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
|
||||
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
|
||||
mlirLinalgInferContractionDimensions(MlirOperation op);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
|
||||
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
|
||||
size_t numMaps);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
|
||||
|
||||
typedef struct MlirLinalgConvolutionDimensions {
|
||||
|
||||
@@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
|
||||
"op.",
|
||||
nb::arg("op"));
|
||||
|
||||
m.def(
|
||||
"infer_contraction_dimensions_from_maps",
|
||||
[](std::vector<MlirAffineMap> indexingMaps)
|
||||
-> std::optional<MlirLinalgContractionDimensions> {
|
||||
if (indexingMaps.empty())
|
||||
return std::nullopt;
|
||||
|
||||
MlirLinalgContractionDimensions dims =
|
||||
mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
|
||||
indexingMaps.size());
|
||||
|
||||
// Detect "empty" result from invalid input or failed inference.
|
||||
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
|
||||
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return dims;
|
||||
},
|
||||
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
|
||||
"maps.",
|
||||
nb::arg("indexing_maps"));
|
||||
|
||||
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
|
||||
"Checks if the given operation is a Linalg convolution operation.",
|
||||
nb::arg("op"));
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Dialect/Linalg.h"
|
||||
#include "mlir/CAPI/AffineMap.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
|
||||
@@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
|
||||
const linalg::ContractionDimensions &contractionDims = *maybeDims;
|
||||
MLIRContext *ctx = linalgOp.getContext();
|
||||
|
||||
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
|
||||
return wrap(
|
||||
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
|
||||
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
|
||||
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
|
||||
};
|
||||
|
||||
result.batch = toAttr(contractionDims.batch);
|
||||
@@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
|
||||
return result;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
|
||||
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
|
||||
size_t numMaps) {
|
||||
MlirLinalgContractionDimensions result{};
|
||||
if (!indexingMaps || numMaps == 0)
|
||||
return result;
|
||||
|
||||
SmallVector<AffineMap, 3> maps;
|
||||
maps.reserve(numMaps);
|
||||
for (size_t i = 0; i < numMaps; ++i) {
|
||||
maps.push_back(unwrap(indexingMaps[i]));
|
||||
}
|
||||
|
||||
FailureOr<linalg::ContractionDimensions> maybeDims =
|
||||
linalg::inferContractionDims(maps);
|
||||
if (failed(maybeDims))
|
||||
return result;
|
||||
|
||||
MLIRContext *ctx = maps[0].getContext();
|
||||
|
||||
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
|
||||
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
|
||||
};
|
||||
|
||||
result.batch = toAttr(maybeDims->batch);
|
||||
result.m = toAttr(maybeDims->m);
|
||||
result.n = toAttr(maybeDims->n);
|
||||
result.k = toAttr(maybeDims->k);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
|
||||
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
|
||||
if (!linalgOp)
|
||||
|
||||
@@ -208,3 +208,43 @@ def test_get_indexing_maps_attr():
|
||||
assert maps[0] == a_map
|
||||
assert maps[1] == b_map
|
||||
assert maps[2] == c_map
|
||||
|
||||
|
||||
@run
|
||||
def test_infer_contraction_dimensions_from_maps():
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
# === Test valid contraction (matmul) ===
|
||||
dim_m = AffineDimExpr.get(0)
|
||||
dim_n = AffineDimExpr.get(1)
|
||||
dim_k = AffineDimExpr.get(2)
|
||||
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
|
||||
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
|
||||
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
|
||||
|
||||
dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
|
||||
assert dims is not None
|
||||
|
||||
# Expect m=[0], n=[1], k=[2] as per standard matmul.
|
||||
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
|
||||
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
|
||||
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
|
||||
assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"
|
||||
|
||||
# === Test invalid input (wrong number of maps) ===
|
||||
invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
|
||||
assert invalid_dims is None
|
||||
|
||||
# === Test element-wise operation ===
|
||||
dim_i = AffineDimExpr.get(0)
|
||||
dim_j = AffineDimExpr.get(1)
|
||||
elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
|
||||
elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
|
||||
[elementwise_map, elementwise_map, elementwise_map]
|
||||
)
|
||||
assert elementwise_dims is not None
|
||||
assert len(elementwise_dims.m) == 0
|
||||
assert len(elementwise_dims.n) == 0
|
||||
assert len(elementwise_dims.k) == 0
|
||||
assert list(elementwise_dims.batch) == [0, 1]
|
||||
|
||||
Reference in New Issue
Block a user