Import ATen conv2d conversion and test (#180)

* Import ATen conv2d conversion and test

This is a first attempt at expanding ATen-to-TCF conversion to cover the
conv2d operator. Eventually, this will come into use when lowering
high-level conv-based models.
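
For illustration, the conversion rewrites a recognized aten.conv2d whose
stride, padding, dilation, and groups are at their defaults into a
tcf.conv_2d_nchw followed by a tcf.add for the bias. A minimal sketch,
mirroring the lit test included below (value names are illustrative):

  %0 = "aten.conv2d"(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>,
         !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64)
      -> tensor<?x?x?x?xf32>

becomes

  %conv = tcf.conv_2d_nchw %input, %weight
      : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %result = tcf.add %conv, %bias
      : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>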
Aaron Arthurs 2021-03-12 19:21:16 -06:00 committed by GitHub
parent 58c7030104
commit 4fd9b4afb5
7 changed files with 258 additions and 0 deletions

@@ -883,6 +883,26 @@ const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
  return metadata;
}

Torch::KernelMetadata Conv2dOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

const Torch::BuildKernelMetadata &Conv2dOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = ([]() {
    Torch::BuildKernelMetadata m;
    // Recognize the overrideable kernel and its plain aten::conv2d alias.
    m.kernelName = "aten::conv2d_overrideable";
    m.aliasKernelNames.push_back("aten::conv2d");
    m.promoteTrailingOutTensor = true;
    // Signature: Tensor, Tensor, Tensor?, int[], int[], int[], int -> Tensor;
    // the three tensor arguments and the result convert to immutable tensors.
    m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"});
    m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor,
                         KVC::kImmutableTensor, KVC::kNone, KVC::kNone,
                         KVC::kNone, KVC::kNone});
    m.addReturnTypes({"Tensor"});
    m.addReturnConversions({KVC::kImmutableTensor});
    return m;
  })();
  return metadata;
}

Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

@@ -494,6 +494,22 @@ def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterface
  );
}

def aten_Conv2dOp: aten_Op<"conv2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::conv2d_overrideable";
  let arguments = (ins
    AnyTorchImmutableTensor:$input,
    AnyTorchImmutableTensor:$weight,
    AnyTorchOptionalImmutableTensor:$bias,
    AnyTorchIntListType:$stride,
    AnyTorchIntListType:$padding,
    AnyTorchIntListType:$dilation,
    AnyTorchIntType:$groups
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
  let arguments = (ins

@@ -12,6 +12,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
using namespace mlir;
@@ -66,6 +67,86 @@ public:
  }
};

/// The ATen Conv2dOp has seven arguments:
///   input, weight, bias, stride, padding, dilation, groups
class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(aten::Conv2dOp srcConv2dOp,
                                PatternRewriter &rewriter) const override {
    auto results = srcConv2dOp.getOperation()->getResults();
    assert(srcConv2dOp.getNumOperands() == 7 && "expected seven (7) operands");
    assert(results.size() == 1 && "expected single result op");
    // TODO: Relax the constant int-list constraints on stride, padding, and
    // dilation, and the constant int constraint on groups.
    auto strideOp = srcConv2dOp.stride().getDefiningOp<Basicpy::BuildListOp>();
    if (!strideOp) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "expected basicpy.build_list to drive stride input");
    }
    if (strideOp.getNumOperands() != 2) {
      return rewriter.notifyMatchFailure(srcConv2dOp,
                                         "expected stride length of 2");
    }
    auto *strideOperand0Op = strideOp.getOperand(0).getDefiningOp();
    auto *strideOperand1Op = strideOp.getOperand(1).getDefiningOp();
    if (!matchPattern(strideOperand0Op, m_One()) ||
        !matchPattern(strideOperand1Op, m_One())) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only "
                       "supports stride == [1, 1]");
    }
    auto paddingOp =
        srcConv2dOp.padding().getDefiningOp<Basicpy::BuildListOp>();
    if (!paddingOp) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "expected basicpy.build_list to drive padding input");
    }
    if (paddingOp.getNumOperands() != 2) {
      return rewriter.notifyMatchFailure(srcConv2dOp,
                                         "expected padding length of 2");
    }
    auto *paddingOperand0Op = paddingOp.getOperand(0).getDefiningOp();
    auto *paddingOperand1Op = paddingOp.getOperand(1).getDefiningOp();
    if (!matchPattern(paddingOperand0Op, m_Zero()) ||
        !matchPattern(paddingOperand1Op, m_Zero())) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only "
                       "supports padding == [0, 0]");
    }
    auto dilationOp =
        srcConv2dOp.dilation().getDefiningOp<Basicpy::BuildListOp>();
    if (!dilationOp) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "expected basicpy.build_list to drive dilation input");
    }
    if (dilationOp.getNumOperands() != 2) {
      return rewriter.notifyMatchFailure(srcConv2dOp,
                                         "expected dilation length of 2");
    }
    auto *dilationOperand0Op = dilationOp.getOperand(0).getDefiningOp();
    auto *dilationOperand1Op = dilationOp.getOperand(1).getDefiningOp();
    if (!matchPattern(dilationOperand0Op, m_One()) ||
        !matchPattern(dilationOperand1Op, m_One())) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only "
                       "supports dilation == [1, 1]");
    }
    if (!matchPattern(srcConv2dOp.groups(), m_One())) {
      return rewriter.notifyMatchFailure(
          srcConv2dOp,
          "aten.conv2d to tcf.conv_2d_nchw currently only supports groups == 1");
    }
    auto tcfConvNCHWOp = rewriter.create<tcf::ConvNCHWOp>(
        srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(),
        srcConv2dOp.input(), srcConv2dOp.weight());
    // TODO: Reference Torch Conv2d's bias flag and only create the trailing
    // TCF add when a bias is actually present
    // (see https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d).
    auto tcfConvNCHWBiasOp = rewriter.create<tcf::AddOp>(
        srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(),
        tcfConvNCHWOp.getResult(), srcConv2dOp.bias());
    rewriter.replaceOp(srcConv2dOp, tcfConvNCHWBiasOp.getResult());
    return success();
  }
};
} // namespace
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
@@ -75,4 +156,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(
  patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
      context);
  patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
  patterns.insert<ConvertATenConv2d>(context);
}
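
Note that the guards above make the pattern conservative: any conv2d whose
stride, padding, dilation, or groups deviates from the defaults is simply left
in ATen form. A hypothetical snippet this pattern declines to convert (same
arguments as in the lit tests below, but with stride == [2, 2]):

  %c2_i64 = constant 2 : i64
  %stride = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
  // The pattern reports "currently only supports stride == [1, 1]" via
  // notifyMatchFailure, and no tcf.conv_2d_nchw is emitted for this op.
  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %stride, %padding, %dilation, %c1_i64)
      : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>,
         !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64)
      -> tensor<?x?x?x?xf32>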

@@ -1,5 +1,18 @@
// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
// CHECK-LABEL: @conv2d
func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // CHECK: %[[CONV2D_RESULT:.*]] = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  // CHECK: tcf.add %[[CONV2D_RESULT]], %arg2 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %c0_i64 = constant 0 : i64
  %c1_i64 = constant 1 : i64
  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
  return %3 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: @binary_elementwise_ops
// NOTE: These are all template expanded, so just testing an exemplar op and
// special cases.

@@ -86,6 +86,18 @@ func @convolution_backward(
  return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
}

// -----
// CHECK-LABEL: func @conv2d
// Contains a Tensor, Tensor, Tensor?, int[], int[], int[], int
func @conv2d(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType, %arg5: !basicpy.ListType, %arg6: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
  // CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
  // CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
  // CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
  // CHECK: %[[TRESULT:.*]] = "aten.conv2d"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<*x!numpy.any_dtype>
  %0 = torch.kernel_call "aten::conv2d" %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  return %0 : !numpy.ndarray<*:!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @copy_inplace
// Mutable/in-place op conversion, dropping result.

@@ -0,0 +1,79 @@
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<2x1x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=BATCH
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<2x2x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=SAME_CHANNELS
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=DIFFERENT_CHANNELS
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=TINY_SQUARE
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=HUGE_SQUARE
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=ZERO_KH_KW
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=ZERO_H_W
// BATCH: output #0: dense<0.000000e+00> : tensor<2x1x1x1xf32>
// SAME_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x2x1x1xf32>
// DIFFERENT_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
// TINY_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x2x2xf32>
// HUGE_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
// ZERO_KH_KW: output #0: dense<0.000000e+00> : tensor<1x1x3x3xf32>
// ZERO_H_W: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c0_i64 = constant 0 : i64
  %c1_i64 = constant 1 : i64
  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
  return %3 : tensor<?x?x?x?xf32>
}
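
The expected output shapes above follow from standard convolution arithmetic.
In general (per the torch.nn.Conv2d documentation),

  H_out = floor((H_in + 2*P - D*(KH - 1) - 1) / S) + 1

and symmetrically for W_out; with the only configuration this conversion
accepts (S = 1, P = 0, D = 1), this reduces to

  H_out = H_in - KH + 1,  W_out = W_in - KW + 1

so, for example, the ZERO_KH_KW case gives H_out = W_out = 2 - 0 + 1 = 3,
matching the checked tensor<1x1x3x3xf32>.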

@@ -0,0 +1,36 @@
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x2x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=CHANNELS
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x3x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=HEIGHT
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
// RUN: -invoke aten_conv_2d_nchw \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x3xf32>" \
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=WIDTH
// CHANNELS: NPCOMP: aborting: input and filter in-channels must be equal
// HEIGHT: NPCOMP: aborting: input height must be greater than or equal to filter KH-dimension
// WIDTH: NPCOMP: aborting: input width must be greater than or equal to filter KW-dimension
func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c0_i64 = constant 0 : i64
  %c1_i64 = constant 1 : i64
  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
  %3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
  return %3 : tensor<?x?x?x?xf32>
}
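
These negative cases exercise tcf.conv_2d_nchw's runtime guards rather than
the conversion itself. As a worked check of the shapes used above:

  required: C_in(input) == C_in(filter),  H_in >= KH,  W_in >= KW
  CHANNELS: C_in(input) = 1 != 2 = C_in(filter)
  HEIGHT:   H_in = 2 < 3 = KH
  WIDTH:    W_in = 2 < 3 = KW

each of which trips the corresponding abort message checked by FileCheck.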