mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][LLVMIR] Fix fusion for rank-0 tensors
Summary: This diff fixes fusion crashing for ops with rank-0 tensors. Reviewers: mravishankar, nicolasvasilache, rriddle! Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76479
This commit is contained in:
@@ -257,7 +257,8 @@ AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
|
||||
results.push_back(
|
||||
expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
|
||||
|
||||
return get(numResultDims, numResultSyms, results);
|
||||
return results.empty() ? get(numResultDims, 0, getContext())
|
||||
: get(numResultDims, numResultSyms, results);
|
||||
}
|
||||
|
||||
AffineMap AffineMap::compose(AffineMap map) {
|
||||
|
||||
@@ -105,3 +105,28 @@ func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg
|
||||
}: tensor<?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
return %2 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[MAP0:.*]] = affine_map<() -> ()>
|
||||
#map0 = affine_map<() -> ()>
|
||||
|
||||
// CHECK-LABEL: @add_mul_scalar_fusion
|
||||
func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
|
||||
{
|
||||
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %arg0, %arg1 {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%1 = addf %arg3, %arg4 : f32
|
||||
linalg.yield %1 : f32
|
||||
}: tensor<f32>, tensor<f32> -> tensor<f32>
|
||||
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64
|
||||
// CHECK: addf
|
||||
// CHECK: mulf
|
||||
%1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %0, %arg2 {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%1 = mulf %arg3, %arg4 : f32
|
||||
linalg.yield %1 : f32
|
||||
}: tensor<f32>, tensor<f32> -> tensor<f32>
|
||||
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user