[mlir][tosa] Migrate tosa to more efficient linalg.conv

Existing linalg.conv2d is not well optimized for performance. Changed to a
version that is more aligned for optimziation. Include the corresponding
transposes to use this optimized version.

This also splits the conv and depthwise conv into separate implementations
to avoid overly complex lowerings.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D107504
This commit is contained in:
Rob Suderman
2021-08-11 11:05:08 -07:00
parent c1a8f12873
commit 7de439b2be
5 changed files with 388 additions and 329 deletions

View File

@@ -144,49 +144,39 @@ def dot(
implements(ContractionOpInterface)
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
@linalg_structured_op
def conv_2d_input_nhwc_filter_ohwi_poly(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
def conv_2d_nchw(
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs a 2-D convolution.
"""Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
O[D.n, D.oh, D.ow, D.oc] += cast(
U, I[D.n,
D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,
D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic])
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
@linalg_structured_op
def conv_2d_input_nhwc_filter_ohwi_poly_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
def conv_2d_nhwc_hwcf(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs a 2-D quantized convolution.
"""Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Includes zero point
adjustment for quantization.
them to the same data type as the accumulator/output.
"""
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
O[D.n, D.oh, D.ow, D.oc] += ((cast(
U, I[D.n,
D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,
D.ic]) - cast(U, IZp)) *
(cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp)))
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
]) * cast(U, K[D.kh, D.kw, D.c, D.f])
@linalg_structured_op
def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
@@ -206,24 +196,27 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
D.c]) * cast(U, K[D.kh, D.kw, D.c])
@linalg_structured_op
def conv_2d_nchw(
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
def conv_2d_nhwc_hwcf_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs 2-D convolution.
"""Performs 2-D convolution with zero point offsets.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
"""
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += (cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp))
def depthwise_conv2D_nchw( #TODO: Fix name
@linalg_structured_op
def depthwise_conv2D_nchw(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
@@ -239,8 +232,8 @@ def depthwise_conv2D_nchw( #TODO: Fix name
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
def depthwise_conv2D_nchw_q( #TODO: Fix name
@linalg_structured_op
def depthwise_conv2D_nchw_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
IZp=ScalarDef(I32),