diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 73ccffdf80fb..66cd8c002ff2 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -161,31 +161,34 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
 /// Unary pointwise operation (with broadcast) entry point.
 using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
-Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
-                            StructuredIndexed I, StructuredIndexed O);
+Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
+                                    StructuredIndexed I, StructuredIndexed O);
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = tanh(I)`. The client is responsible for specifying the proper
 /// indexings when creating the StructuredIndexed.
-Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
+Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
+                                         StructuredIndexed O);
 /// Binary pointwise operation (with broadcast) entry point.
 using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
-Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
-                            StructuredIndexed I1, StructuredIndexed I2,
-                            StructuredIndexed O);
+Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
+                                    StructuredIndexed I1, StructuredIndexed I2,
+                                    StructuredIndexed O);
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = I1 + I2`. The client is responsible for specifying the proper
 /// indexings when creating the StructuredIndexed.
-Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
-                                StructuredIndexed O);
+Operation *linalg_generic_pointwise_add(StructuredIndexed I1,
+                                        StructuredIndexed I2,
+                                        StructuredIndexed O);
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = max(I1, I2)`. The client is responsible for specifying the
 /// proper indexings when creating the StructuredIndexed.
-Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
-                                StructuredIndexed O);
+Operation *linalg_generic_pointwise_max(StructuredIndexed I1,
+                                        StructuredIndexed I2,
+                                        StructuredIndexed O);
 // TODO(ntv): Implement more useful pointwise operations on a per-need basis.
@@ -198,8 +201,9 @@ using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
 /// |
 /// |  C(m, n) += A(m, k) * B(k, n)
 /// ```
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
-                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
+Operation *
+linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                      MatmulRegionBuilder regionBuilder = macRegionBuilder);
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
@@ -209,8 +213,9 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
 /// |  C(m, n) = sum_k(A(m, k) * B(k, n))
 /// ```
 /// and returns the tensor `C`.
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
-                         MatmulRegionBuilder regionBuilder = mulRegionBuilder);
+Operation *
+linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
+                      MatmulRegionBuilder regionBuilder = mulRegionBuilder);
 /// Build a linalg.generic, under the current ScopedContext, at the current
 /// insert point, that computes:
@@ -220,15 +225,17 @@ Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
 /// |  D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
 /// ```
 /// and returns the tensor `D`.
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
-                         RankedTensorType tD,
-                         MatmulRegionBuilder regionBuilder = macRegionBuilder);
+Operation *
+linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                      RankedTensorType tD,
+                      MatmulRegionBuilder regionBuilder = macRegionBuilder);
 template <typename Container>
-Operation *linalg_matmul(Container values,
-                         MatmulRegionBuilder regionBuilder = macRegionBuilder) {
+Operation *
+linalg_generic_matmul(Container values,
+                      MatmulRegionBuilder regionBuilder = macRegionBuilder) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_matmul(values[0], values[1], values[2], regionBuilder);
+  return linalg_generic_matmul(values[0], values[1], values[2], regionBuilder);
 }
 /// Build a linalg.generic, under the current ScopedContext, at the current
@@ -253,15 +260,17 @@ Operation *linalg_matmul(Container values,
 /// For now `...` must be empty (i.e. only 2-D convolutions are supported).
 ///
 // TODO(ntv) Extend convolution rank with some template magic.
-Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
-                            ArrayRef<int64_t> strides = {},
-                            ArrayRef<int64_t> dilations = {});
+Operation *linalg_generic_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                    ValueHandle vO, ArrayRef<int64_t> strides = {},
+                                    ArrayRef<int64_t> dilations = {});
 template <typename Container>
-Operation *linalg_conv_nhwc(Container values, ArrayRef<int64_t> strides = {},
-                            ArrayRef<int64_t> dilations = {}) {
+Operation *linalg_generic_conv_nhwc(Container values,
+                                    ArrayRef<int64_t> strides = {},
+                                    ArrayRef<int64_t> dilations = {}) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
+  return linalg_generic_conv_nhwc(values[0], values[1], values[2], strides,
+                                  dilations);
 }
 /// Build a linalg.generic, under the current ScopedContext, at the current
@@ -286,18 +295,20 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int64_t> strides = {},
 /// For now `...` must be empty (i.e. only 2-D convolutions are supported).
 ///
 // TODO(ntv) Extend convolution rank with some template magic.
-Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
-                                    ValueHandle vO, int depth_multiplier = 1,
-                                    ArrayRef<int64_t> strides = {},
-                                    ArrayRef<int64_t> dilations = {});
+Operation *linalg_generic_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                            ValueHandle vO,
+                                            int depth_multiplier = 1,
+                                            ArrayRef<int64_t> strides = {},
+                                            ArrayRef<int64_t> dilations = {});
 template <typename Container>
-Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
-                                    ArrayRef<int64_t> strides = {},
-                                    ArrayRef<int64_t> dilations = {}) {
+Operation *linalg_generic_dilated_conv_nhwc(Container values,
+                                            int depth_multiplier,
+                                            ArrayRef<int64_t> strides = {},
+                                            ArrayRef<int64_t> dilations = {}) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
-                                  depth_multiplier, strides, dilations);
+  return linalg_generic_dilated_conv_nhwc(values[0], values[1], values[2],
+                                          depth_multiplier, strides, dilations);
 }
 } // namespace ops
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
index 98ff016182fa..dedc18934b84 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -26,7 +26,10 @@ ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
 namespace intrinsics {
 using linalg_copy = OperationBuilder<linalg::CopyOp>;
+using linalg_dot = OperationBuilder<linalg::DotOp>;
 using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_matmul = OperationBuilder<linalg::MatmulOp>;
+using linalg_matvec = OperationBuilder<linalg::MatvecOp>;
 using linalg_range = ValueBuilder<linalg::RangeOp>;
 using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
 using linalg_slice = ValueBuilder<linalg::SliceOp>;
diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Builders.h b/mlir/include/mlir/Dialect/Vector/EDSC/Builders.h
index 024ae93a8f32..396053f63213 100644
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Builders.h
@@ -44,7 +44,7 @@ Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
 /// Prerequisites:
 /// A, B and C capture values of proper vector types. For instance
 /// `A: vector<4x8xf32>`, `B: vector<8x16xf32>` and `C: vector<4x16xf32>`.
-Value vector_matmul(Value A, Value B, Value C);
+Value vector_contraction_matmul(Value A, Value B, Value C);
 } // namespace ops
 } // namespace edsc
diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
index c307721607df..79ab479c6133 100644
--- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
@@ -16,6 +16,7 @@ namespace intrinsics {
 using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
 using vector_contract = ValueBuilder<vector::ContractionOp>;
+using vector_matmul = ValueBuilder<vector::MatmulOp>;
 using vector_print = OperationBuilder<vector::PrintOp>;
 } // namespace intrinsics
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 198c7fc698dd..10c18107fd8e 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -221,9 +221,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
   linalg_yield((c + a * b).getValue());
 }
-Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
-                                             StructuredIndexed I,
-                                             StructuredIndexed O) {
+Operation *mlir::edsc::ops::linalg_generic_pointwise(
+    UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
                                          IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
@@ -242,18 +241,17 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
   return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
 }
-Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
-                                                  StructuredIndexed O) {
+Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
+                                                          StructuredIndexed O) {
   UnaryPointwiseOpBuilder unOp(
       [](ValueHandle a) -> Value { return std_tanh(a); });
-  return linalg_pointwise(unOp, I, O);
+  return linalg_generic_pointwise(unOp, I, O);
 }
 /// Binary pointwise operation (with broadcast) entry point.
-Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
-                                             StructuredIndexed I1,
-                                             StructuredIndexed I2,
-                                             StructuredIndexed O) {
+Operation *mlir::edsc::ops::linalg_generic_pointwise(
+    BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1,
+    StructuredIndexed I2, StructuredIndexed O) {
   SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
                                          IteratorType::Parallel);
   if (O.getType().isa<RankedTensorType>()) {
@@ -272,28 +270,29 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
   return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
 }
-Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
-                                                 StructuredIndexed I2,
-                                                 StructuredIndexed O) {
+Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
+                                                         StructuredIndexed I2,
+                                                         StructuredIndexed O) {
   using edsc::op::operator+;
   BinaryPointwiseOpBuilder binOp(
       [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
-  return linalg_pointwise(binOp, I1, I2, O);
+  return linalg_generic_pointwise(binOp, I1, I2, O);
 }
-Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
-                                                 StructuredIndexed I2,
-                                                 StructuredIndexed O) {
+Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
+                                                         StructuredIndexed I2,
+                                                         StructuredIndexed O) {
   BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
     using edsc::op::operator>;
     return std_select(a > b, a, b).getValue();
   });
-  return linalg_pointwise(binOp, I1, I2, O);
+  return linalg_generic_pointwise(binOp, I1, I2, O);
 }
-Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC,
-                                          MatmulRegionBuilder regionBuilder) {
+Operation *
+mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
+                                       ValueHandle vC,
+                                       MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -306,9 +305,10 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   // clang-format on
 }
-Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          RankedTensorType tC,
-                                          MatmulRegionBuilder regionBuilder) {
+Operation *
+mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
+                                       RankedTensorType tC,
+                                       MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -321,9 +321,10 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   // clang-format on
 }
-Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC, RankedTensorType tD,
-                                          MatmulRegionBuilder regionBuilder) {
+Operation *
+mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
+                                       ValueHandle vC, RankedTensorType tD,
+                                       MatmulRegionBuilder regionBuilder) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
@@ -336,10 +337,11 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
   // clang-format on
 }
-Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
-                                             ValueHandle vO,
-                                             ArrayRef<int64_t> strides,
-                                             ArrayRef<int64_t> dilations) {
+Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
+                                                     ValueHandle vW,
+                                                     ValueHandle vO,
+                                                     ArrayRef<int64_t> strides,
+                                                     ArrayRef<int64_t> dilations) {
   MLIRContext *ctx = ScopedContext::getContext();
   // TODO(ntv) some template magic to make everything rank-polymorphic.
   assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
@@ -370,7 +372,7 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
   // clang-format on
 }
-Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
+Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
     ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
     ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations) {
   MLIRContext *ctx = ScopedContext::getContext();
diff --git a/mlir/lib/Dialect/Vector/EDSC/Builders.cpp b/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
index b1f94655ab28..d2436ef8ae62 100644
--- a/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Vector/EDSC/Builders.cpp
@@ -30,7 +30,7 @@ Value mlir::edsc::ops::vector_contraction(
       ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
 }
-Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
+Value mlir::edsc::ops::vector_contraction_matmul(Value A, Value B, Value C) {
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
   return vector_contraction(StructuredIndexed(A, {m, k}),
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 9b578aca229a..0c725e98fa3b 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -804,7 +804,7 @@ TEST_FUNC(affine_if_op) {
 }
 // clang-format off
-// CHECK-LABEL: func @linalg_pointwise
+// CHECK-LABEL: func @linalg_generic_pointwise
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
@@ -822,14 +822,14 @@ TEST_FUNC(affine_if_op) {
 // CHECK: tanh
 // CHECK: }: memref<?x?xf32>, memref<?x?xf32>
 // clang-format on
-TEST_FUNC(linalg_pointwise_test) {
+TEST_FUNC(linalg_generic_pointwise_test) {
   using namespace edsc;
   using namespace edsc::ops;
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType = MemRefType::get(
       {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
-  auto f = makeFunction("linalg_pointwise", {},
+  auto f = makeFunction("linalg_generic_pointwise", {},
                         {memrefType, memrefType, memrefType});
   OpBuilder builder(f.getBody());
@@ -838,16 +838,16 @@ TEST_FUNC(affine_if_op) {
   AffineExpr i, j;
   bindDims(&globalContext(), i, j);
   StructuredIndexed SA(A), SB(B), SC(C);
-  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
+  linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
   f.print(llvm::outs());
   f.erase();
 }
 // clang-format off
-// CHECK-LABEL: func @linalg_matmul
+// CHECK-LABEL: func @linalg_generic_matmul
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
@@ -857,7 +857,7 @@ TEST_FUNC(linalg_pointwise_test) {
 // CHECK: linalg.yield %[[a4]] : f32
 // CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
 // clang-format on
-TEST_FUNC(linalg_matmul_test) {
+TEST_FUNC(linalg_generic_matmul_test) {
   using namespace edsc;
   using namespace edsc::ops;
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType = MemRefType::get(
       {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
   auto f =
-      makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
+      makeFunction("linalg_generic_matmul", {}, {memrefType, memrefType, memrefType});
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
-  linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
+  linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
   f.print(llvm::outs());
   f.erase();
 }
 // clang-format off
-// CHECK-LABEL: func @linalg_conv_nhwc
+// CHECK-LABEL: func @linalg_generic_conv_nhwc
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>,
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
@@ -888,7 +888,7 @@ TEST_FUNC(linalg_matmul_test) {
 // CHECK: linalg.yield %[[a4]] : f32
 // CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
 // clang-format on
-TEST_FUNC(linalg_conv_nhwc) {
+TEST_FUNC(linalg_generic_conv_nhwc) {
   using namespace edsc;
   using namespace edsc::ops;
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType =
       MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
                        ShapedType::kDynamicSize, ShapedType::kDynamicSize},
                       f32Type, {}, 0);
-  auto f = makeFunction("linalg_conv_nhwc", {},
+  auto f = makeFunction("linalg_generic_conv_nhwc", {},
                         {memrefType, memrefType, memrefType});
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
-  linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+  linalg_generic_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
                    /*strides=*/{3, 4}, /*dilations=*/{5, 6});
   f.print(llvm::outs());
@@ -910,7 +910,7 @@
 }
 // clang-format off
-// CHECK-LABEL: func @linalg_dilated_conv_nhwc
+// CHECK-LABEL: func @linalg_generic_dilated_conv_nhwc
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2)>,
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
@@ -922,7 +922,7 @@ TEST_FUNC(linalg_conv_nhwc) {
 // CHECK: linalg.yield %[[a4]] : f32
 // CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
 // clang-format on
-TEST_FUNC(linalg_dilated_conv_nhwc) {
+TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
   using namespace edsc;
   using namespace edsc::ops;
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType =
       MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
                        ShapedType::kDynamicSize, ShapedType::kDynamicSize},
                       f32Type, {}, 0);
-  auto f = makeFunction("linalg_dilated_conv_nhwc", {},
+  auto f = makeFunction("linalg_generic_dilated_conv_nhwc", {},
                         {memrefType, memrefType, memrefType});
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
-  linalg_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
+  linalg_generic_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
                            /*depth_multiplier=*/7,
                            /*strides=*/{3, 4}, /*dilations=*/{5, 6});
@@ -1019,11 +1019,11 @@ TEST_FUNC(linalg_tensors_test) {
   AffineExpr i, j;
   bindDims(&globalContext(), i, j);
   StructuredIndexed SA(A), SB(B), SC(tensorType);
-  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
-  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
-  Value o1 = linalg_matmul(A, B, tensorType)->getResult(0);
-  linalg_matmul(A, B, ValueHandle(o1), tensorType);
+  linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
+  Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
+  linalg_generic_matmul(A, B, ValueHandle(o1), tensorType);
   f.print(llvm::outs());
   f.erase();
@@ -1067,9 +1067,9 @@ TEST_FUNC(memref_vector_matmul_test) {
   ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
   auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
     assert(args.size() == 3 && "expected 3 block arguments");
-    (linalg_yield(vector_matmul(args[0], args[1], args[2])));
+    (linalg_yield(vector_contraction_matmul(args[0], args[1], args[2])));
   };
-  linalg_matmul(A, B, C, contractionBuilder);
+  linalg_generic_matmul(A, B, C, contractionBuilder);
   f.print(llvm::outs());
   f.erase();
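
Caller-side usage after this rename is unchanged apart from the symbol names. The sketch below is not part of the patch; it only mirrors the updated builder-api-test.cpp and assumes a FuncOp `f` with three memref<?x?xf32> arguments plus the `makeValueHandles` helper the test already relies on.

// Sketch only: exercises the renamed linalg_generic_* EDSC builders the same
// way the updated builder-api-test.cpp does. `f` is assumed to be a FuncOp
// taking three memref<?x?xf32> arguments (A, B, C); additional includes may be
// needed depending on the translation unit.
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::ops;

static void buildGenericMatmulAndPointwise(FuncOp f) {
  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());

  // Emits a linalg.generic implementing C(m, n) += A(m, k) * B(k, n).
  linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));

  // Emits a linalg.generic computing C = A + B elementwise.
  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
  AffineExpr i, j;
  bindDims(f.getContext(), i, j);
  StructuredIndexed SA(A), SB(B), SC(C);
  linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
}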