mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 17:45:07 +08:00
[mlir][tosa] Tosa elementwise broadcasting had some minor bugs
Updated tests to include broadcast of left and right. Includes bypass if in-type and out-type match shape (no broadcasting). Differential Revision: https://reviews.llvm.org/D102276
This commit is contained in:
@@ -522,8 +522,12 @@ static LogicalResult
|
||||
elementwiseMatchAndRewriteHelper(Operation *operation,
|
||||
PatternRewriter &rewriter) {
|
||||
auto loc = operation->getLoc();
|
||||
|
||||
assert(operation->getNumResults() == 1 &&
|
||||
"All TOSA elementwise ops should only return a single result.");
|
||||
|
||||
auto results = operation->getResults();
|
||||
auto resultTy = operation->getOperand(0).getType().dyn_cast<ShapedType>();
|
||||
auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
|
||||
|
||||
if (!resultTy)
|
||||
return rewriter.notifyMatchFailure(operation,
|
||||
@@ -531,9 +535,6 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
|
||||
|
||||
unsigned rank = resultTy.getRank();
|
||||
|
||||
assert(operation->getNumResults() == 1 &&
|
||||
"All TOSA elementwise ops should only return a single result.");
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
SmallVector<Type> bodyArgTypes;
|
||||
|
||||
@@ -565,11 +566,18 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
|
||||
// Input indexing maps may be broadcasted.
|
||||
for (Value operand : operation->getOperands()) {
|
||||
ShapedType type = operand.getType().cast<ShapedType>();
|
||||
|
||||
if (type.getShape() == resultTy.getShape()) {
|
||||
operands.push_back(operand);
|
||||
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 5> newShape;
|
||||
SmallVector<AffineExpr, 4> affineExprs;
|
||||
newShape.reserve(type.getRank());
|
||||
for (auto it : llvm::enumerate(type.getShape())) {
|
||||
if (it.value() != 1) {
|
||||
if (it.value() == resultTy.getDimSize(it.index())) {
|
||||
newShape.push_back(it.value());
|
||||
affineExprs.push_back(
|
||||
mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
|
||||
|
||||
@@ -73,6 +73,24 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
|
||||
|
||||
// CHECK-LABEL: @test_broadcast_swapped_args
|
||||
func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
|
||||
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1
|
||||
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
|
||||
// CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
|
||||
// CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
|
||||
// CHECK: linalg.yield [[ELEMENT]] : f32
|
||||
// CHECK: } -> tensor<2xf32>
|
||||
%0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
|
||||
|
||||
Reference in New Issue
Block a user