diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
index 5607dd215..0707b82d6 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.cpp.inc
@@ -883,6 +883,26 @@ const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
   return metadata;
 }
 
+Torch::KernelMetadata Conv2dOp::getTorchKernelMetadata() {
+  return getTorchBuildKernelMetadata();
+}
+
+const Torch::BuildKernelMetadata &Conv2dOp::getTorchBuildKernelMetadata() {
+  using KVC = Torch::KernelValueConversion::BitMask;
+  static Torch::BuildKernelMetadata metadata = ([]() {
+    Torch::BuildKernelMetadata m;
+    m.kernelName = "aten::conv2d_overrideable";
+    m.aliasKernelNames.push_back("aten::conv2d");
+    m.promoteTrailingOutTensor = true;
+    m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"});
+    m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
+    m.addReturnTypes({"Tensor"});
+    m.addReturnConversions({KVC::kImmutableTensor});
+    return m;
+  })();
+  return metadata;
+}
+
 Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
   return getTorchBuildKernelMetadata();
 }
diff --git a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
index 479694124..712d7a7d3 100644
--- a/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
+++ b/include/npcomp/Dialect/ATen/IR/GeneratedATenOps.td
@@ -494,6 +494,22 @@ def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
   );
 }
 
+def aten_Conv2dOp: aten_Op<"conv2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
+  let summary = "Recognized op for kernel aten::conv2d_overrideable";
+  let arguments = (ins
+    AnyTorchImmutableTensor:$input,
+    AnyTorchImmutableTensor:$weight,
+    AnyTorchOptionalImmutableTensor:$bias,
+    AnyTorchIntListType:$stride,
+    AnyTorchIntListType:$padding,
+    AnyTorchIntListType:$dilation,
+    AnyTorchIntType:$groups
+  );
+  let results = (outs
+    AnyTorchImmutableTensor
+  );
+}
+
 def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
   let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
   let arguments = (ins
diff --git a/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp b/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp
index edd0432da..8851dc426 100644
--- a/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp
+++ b/lib/Conversion/ATenToTCF/CoreOpConversionPatterns.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "npcomp/Dialect/ATen/IR/ATenDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
 #include "npcomp/Dialect/TCF/IR/TCFOps.h"
 
 using namespace mlir;
@@ -66,6 +67,86 @@ public:
   }
 };
 
+/// The ATen Conv2dOp has seven arguments:
+///   input, weight, bias, stride, padding, dilation, groups
+
+class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(aten::Conv2dOp srcConv2dOp,
+                                PatternRewriter &rewriter) const override {
+    auto results = srcConv2dOp.getOperation()->getResults();
+    assert(srcConv2dOp.getNumOperands() == 7 && "expected seven (7) operands");
+    assert(results.size() == 1 && "expected single result op");
+    // TODO: Replace the constant int-list constraints for stride, padding,
+    // and dilation, and the constant int constraint for groups.
+    auto strideOp = srcConv2dOp.stride().getDefiningOp<Basicpy::BuildListOp>();
+    if (!strideOp) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected basicpy.build_list to drive stride input");
+    }
+    if (strideOp.getNumOperands() != 2) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected stride length of 2");
+    }
+    auto *strideOperand0Op = strideOp.getOperand(0).getDefiningOp();
+    auto *strideOperand1Op = strideOp.getOperand(1).getDefiningOp();
+    if (!matchPattern(strideOperand0Op, m_One()) ||
+        !matchPattern(strideOperand1Op, m_One())) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp,
+          "aten.conv2d to tcf.conv_2d_nchw currently only supports stride == [1, 1]");
+    }
+    auto paddingOp = srcConv2dOp.padding().getDefiningOp<Basicpy::BuildListOp>();
+    if (!paddingOp) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected basicpy.build_list to drive padding input");
+    }
+    if (paddingOp.getNumOperands() != 2) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected padding length of 2");
+    }
+    auto *paddingOperand0Op = paddingOp.getOperand(0).getDefiningOp();
+    auto *paddingOperand1Op = paddingOp.getOperand(1).getDefiningOp();
+    if (!matchPattern(paddingOperand0Op, m_Zero()) ||
+        !matchPattern(paddingOperand1Op, m_Zero())) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp,
+          "aten.conv2d to tcf.conv_2d_nchw currently only supports padding == [0, 0]");
+    }
+    auto dilationOp = srcConv2dOp.dilation().getDefiningOp<Basicpy::BuildListOp>();
+    if (!dilationOp) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected basicpy.build_list to drive dilation input");
+    }
+    if (dilationOp.getNumOperands() != 2) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "expected dilation length of 2");
+    }
+    auto *dilationOperand0Op = dilationOp.getOperand(0).getDefiningOp();
+    auto *dilationOperand1Op = dilationOp.getOperand(1).getDefiningOp();
+    if (!matchPattern(dilationOperand0Op, m_One()) ||
+        !matchPattern(dilationOperand1Op, m_One())) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp,
+          "aten.conv2d to tcf.conv_2d_nchw currently only supports dilation == [1, 1]");
+    }
+    if (!matchPattern(srcConv2dOp.groups(), m_One())) {
+      return rewriter.notifyMatchFailure(
+          srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only supports groups == 1");
+    }
+    auto tcfConvNCHWOp = rewriter.create<tcf::ConvNCHWOp>(
+        srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(),
+        srcConv2dOp.input(), srcConv2dOp.weight());
+    // TODO: Reference Torch Conv2D's bias flag to conditionally create TCF Add.
+    //       (see https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
+    auto tcfConvNCHWBiasOp = rewriter.create<tcf::AddOp>(
+        srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(),
+        tcfConvNCHWOp.getResult(), srcConv2dOp.bias());
+    rewriter.replaceOp(srcConv2dOp, tcfConvNCHWBiasOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::NPCOMP::populateCoreATenToTCFPatterns(
@@ -75,4 +156,5 @@
   patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
       context);
   patterns.insert<ConvertUnaryOp<aten::TanhOp, tcf::TanhOp>>(context);
+  patterns.insert<ConvertATenConv2d>(context);
 }
diff --git a/test/Conversion/ATenToTCF/core_op_conversions.mlir b/test/Conversion/ATenToTCF/core_op_conversions.mlir
index 38c301dca..f7a3f3724 100644
--- a/test/Conversion/ATenToTCF/core_op_conversions.mlir
+++ b/test/Conversion/ATenToTCF/core_op_conversions.mlir
@@ -1,5 +1,18 @@
 // RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
 
+// CHECK-LABEL: @conv2d
+func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK: %[[CONV2D_RESULT:.*]] = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: tcf.add %[[CONV2D_RESULT]], %arg2 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %c0_i64 = constant 0 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
+  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
+  return %3 : tensor<?x?x?x?xf32>
+}
+
 // CHECK-LABEL: @binary_elementwise_ops
 // NOTE: These are all template expanded, so just testing an examplar op and
 // special cases.
diff --git a/test/Dialect/ATen/recognize_aten_kernels.mlir b/test/Dialect/ATen/recognize_aten_kernels.mlir
index 5963e2094..27986328f 100644
--- a/test/Dialect/ATen/recognize_aten_kernels.mlir
+++ b/test/Dialect/ATen/recognize_aten_kernels.mlir
@@ -86,6 +86,18 @@ func @convolution_backward(
   return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
 }
 
+// -----
+// CHECK-LABEL: func @conv2d
+// Contains a Tensor, Tensor, Tensor?, int[], int[], int[], int
+func @conv2d(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType, %arg5: !basicpy.ListType, %arg6: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
+  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
+  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
+  // CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
+  // CHECK: %[[TRESULT:.*]] = "aten.conv2d"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<*x!numpy.any_dtype>
+  %0 = torch.kernel_call "aten::conv2d" %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
+  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
 // -----
 // CHECK-LABEL: func @copy_inplace
 // Mutable/in-place op conversion, dropping result.
diff --git a/test/npcomp-run-mlir/aten/conv_2d_nchw.mlir b/test/npcomp-run-mlir/aten/conv_2d_nchw.mlir
new file mode 100644
index 000000000..405fb3395
--- /dev/null
+++ b/test/npcomp-run-mlir/aten/conv_2d_nchw.mlir
@@ -0,0 +1,79 @@
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<2x1x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=BATCH
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<2x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=SAME_CHANNELS
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=DIFFERENT_CHANNELS
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=TINY_SQUARE
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=HUGE_SQUARE
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=ZERO_KH_KW
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=ZERO_H_W
+
+// BATCH: output #0: dense<0.000000e+00> : tensor<2x1x1x1xf32>
+
+// SAME_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x2x1x1xf32>
+
+// DIFFERENT_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+// TINY_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x2x2xf32>
+
+// HUGE_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+// ZERO_KH_KW: output #0: dense<0.000000e+00> : tensor<1x1x3x3xf32>
+
+// ZERO_H_W: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %c0_i64 = constant 0 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
+  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
+  return %3 : tensor<?x?x?x?xf32>
+}
diff --git a/test/npcomp-run-mlir/aten/invalid-conv_2d_nchw.mlir b/test/npcomp-run-mlir/aten/invalid-conv_2d_nchw.mlir
new file mode 100644
index 000000000..896c9d28f
--- /dev/null
+++ b/test/npcomp-run-mlir/aten/invalid-conv_2d_nchw.mlir
@@ -0,0 +1,36 @@
+// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=CHANNELS
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x3x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=HEIGHT
+
+// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
+// RUN:   -invoke aten_conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x3xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=WIDTH
+
+// CHANNELS: NPCOMP: aborting: input and filter in-channels must be equal
+// HEIGHT: NPCOMP: aborting: input height must be greater than or equal to filter KH-dimension
+// WIDTH: NPCOMP: aborting: input width must be greater than or equal to filter KW-dimension
+func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %c0_i64 = constant 0 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
+  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
+  return %3 : tensor<?x?x?x?xf32>
+}
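
For illustration (not part of the patch): the new ConvertATenConv2d pattern rewrites an aten.conv2d call into a tcf.conv_2d_nchw on the input and weight, followed by a tcf.add for the bias. A minimal sketch of the converted IR, derived from the FileCheck expectations in core_op_conversions.mlir and assuming dynamically shaped tensor<?x?x?x?xf32> operands:

func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // Convolution over input and weight first, then bias addition, mirroring
  // the two TCF ops the pattern emits.
  %0 = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %1 = tcf.add %0, %arg2 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}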