mirror of https://github.com/llvm/torch-mlir
Import ATen conv2d conversion and test (#180)
* Import ATen conv2d conversion and test This is a first attempt at expanding ATen-to-TCF conversion for the conv2d operator. Eventually, this will come in use when lowering a high-level conv-based model.pull/187/head
parent
58c7030104
commit
4fd9b4afb5
|
@ -883,6 +883,26 @@ const Torch::BuildKernelMetadata &ConvolutionOp::getTorchBuildKernelMetadata() {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata Conv2dOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
||||
const Torch::BuildKernelMetadata &Conv2dOp::getTorchBuildKernelMetadata() {
|
||||
using KVC = Torch::KernelValueConversion::BitMask;
|
||||
static Torch::BuildKernelMetadata metadata = ([]() {
|
||||
Torch::BuildKernelMetadata m;
|
||||
m.kernelName = "aten::conv2d_overrideable";
|
||||
m.aliasKernelNames.push_back("aten::conv2d");
|
||||
m.promoteTrailingOutTensor = true;
|
||||
m.addArgTypes({"Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"});
|
||||
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kImmutableTensor, KVC::kNone, KVC::kNone, KVC::kNone, KVC::kNone});
|
||||
m.addReturnTypes({"Tensor"});
|
||||
m.addReturnConversions({KVC::kImmutableTensor});
|
||||
return m;
|
||||
})();
|
||||
return metadata;
|
||||
}
|
||||
|
||||
Torch::KernelMetadata ConvolutionBackwardOp::getTorchKernelMetadata() {
|
||||
return getTorchBuildKernelMetadata();
|
||||
}
|
||||
|
|
|
@ -494,6 +494,22 @@ def aten_ConvolutionOp: aten_Op<"convolution", [NoSideEffect, DeclareOpInterface
|
|||
);
|
||||
}
|
||||
|
||||
def aten_Conv2dOp: aten_Op<"conv2d", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::conv2d_overrideable";
|
||||
let arguments = (ins
|
||||
AnyTorchImmutableTensor:$input,
|
||||
AnyTorchImmutableTensor:$weight,
|
||||
AnyTorchOptionalImmutableTensor:$bias,
|
||||
AnyTorchIntListType:$stride,
|
||||
AnyTorchIntListType:$padding,
|
||||
AnyTorchIntListType:$dilation,
|
||||
AnyTorchIntType:$groups
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchImmutableTensor
|
||||
);
|
||||
}
|
||||
|
||||
def aten_ConvolutionBackwardOp: aten_Op<"convolution_backward", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
|
||||
let summary = "Recognized op for kernel aten::convolution_backward_overrideable";
|
||||
let arguments = (ins
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -66,6 +67,86 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// The ATen Conv2dOp has seven arguments:
|
||||
/// input, weight, bias, stride, padding, dilation, groups
|
||||
|
||||
class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(aten::Conv2dOp srcConv2dOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto results = srcConv2dOp.getOperation()->getResults();
|
||||
assert(srcConv2dOp.getNumOperands() == 7 && "expected seven (7) operands");
|
||||
assert(results.size() == 1 && "expected single result op");
|
||||
// TODO: Replace constant int-list constraints for stride, padding, and dilation; and, constant int constraint for groups.
|
||||
auto strideOp = srcConv2dOp.stride().getDefiningOp<Basicpy::BuildListOp>();
|
||||
if (!strideOp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected basicpy.build_list to drive stride input");
|
||||
}
|
||||
if (strideOp.getNumOperands() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected stride length of 2");
|
||||
}
|
||||
auto *strideOperand0Op = strideOp.getOperand(0).getDefiningOp();
|
||||
auto *strideOperand1Op = strideOp.getOperand(1).getDefiningOp();
|
||||
if (!matchPattern(strideOperand0Op, m_One())
|
||||
|| !matchPattern(strideOperand1Op, m_One())
|
||||
) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only supports stride == [1, 1]");
|
||||
}
|
||||
auto paddingOp = srcConv2dOp.padding().getDefiningOp<Basicpy::BuildListOp>();
|
||||
if (!paddingOp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected basicpy.build_list to drive padding input");
|
||||
}
|
||||
if (paddingOp.getNumOperands() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected padding length of 2");
|
||||
}
|
||||
auto *paddingOperand0Op = paddingOp.getOperand(0).getDefiningOp();
|
||||
auto *paddingOperand1Op = paddingOp.getOperand(1).getDefiningOp();
|
||||
if (!matchPattern(paddingOperand0Op, m_Zero())
|
||||
|| !matchPattern(paddingOperand1Op, m_Zero())
|
||||
) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only supports padding == [0, 0]");
|
||||
}
|
||||
auto dilationOp = srcConv2dOp.dilation().getDefiningOp<Basicpy::BuildListOp>();
|
||||
if (!dilationOp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected basicpy.build_list to drive dilation input");
|
||||
}
|
||||
if (dilationOp.getNumOperands() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "expected dilation length of 2");
|
||||
}
|
||||
auto *dilationOperand0Op = dilationOp.getOperand(0).getDefiningOp();
|
||||
auto *dilationOperand1Op = dilationOp.getOperand(1).getDefiningOp();
|
||||
if (!matchPattern(dilationOperand0Op, m_One())
|
||||
|| !matchPattern(dilationOperand1Op, m_One())
|
||||
) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only supports dilation == [1, 1]");
|
||||
}
|
||||
if (!matchPattern(srcConv2dOp.groups(), m_One())
|
||||
) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
srcConv2dOp, "aten.conv2d to tcf.conv_2d_nchw currently only supports groups == 1");
|
||||
}
|
||||
auto tcfConvNCHWOp = rewriter.create<tcf::ConvNCHWOp>(
|
||||
srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(), srcConv2dOp.input(),
|
||||
srcConv2dOp.weight());
|
||||
// TODO: Reference Torch Conv2D's bias flag to conditionally create TCF Add.
|
||||
// (see https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
|
||||
auto tcfConvNCHWBiasOp = rewriter.create<tcf::AddOp>(
|
||||
srcConv2dOp.getLoc(), srcConv2dOp.getResult().getType(), tcfConvNCHWOp.getResult(),
|
||||
srcConv2dOp.bias());
|
||||
rewriter.replaceOp(srcConv2dOp, tcfConvNCHWBiasOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(
|
||||
|
@ -75,4 +156,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(
|
|||
patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
|
||||
context);
|
||||
patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
||||
patterns.insert<ConvertATenConv2d>(context);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,18 @@
|
|||
// RUN: npcomp-opt <%s -convert-aten-to-tcf | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: @conv2d
|
||||
func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
// CHECK: %[[CONV2D_RESULT:.*]] = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: tcf.add %[[CONV2D_RESULT]], %arg2 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
|
||||
return %3 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @binary_elementwise_ops
|
||||
// NOTE: These are all template expanded, so just testing an examplar op and
|
||||
// special cases.
|
||||
|
|
|
@ -86,6 +86,18 @@ func @convolution_backward(
|
|||
return %0#0, %0#1, %0#2 : !basicpy.NoneType, !numpy.ndarray<[4,16,3,3]:f32>, !numpy.ndarray<[4]:f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @conv2d
|
||||
// Contains a Tensor, Tensor, Tensor?, int[], int[] int[], int
|
||||
func @conv2d(%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>, %arg3: !basicpy.ListType, %arg4: !basicpy.ListType, %arg5: !basicpy.ListType, %arg6: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
// CHECK: %[[TARG0:.*]] = numpy.copy_to_tensor %arg0
|
||||
// CHECK: %[[TARG1:.*]] = numpy.copy_to_tensor %arg1
|
||||
// CHECK: %[[TARG2:.*]] = numpy.copy_to_tensor %arg2
|
||||
// CHECK: %[[TRESULT:.*]] = "aten.conv2d"(%[[TARG0]], %[[TARG1]], %[[TARG2]], %arg3, %arg4, %arg5, %arg6) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<*x!numpy.any_dtype>
|
||||
%0 = torch.kernel_call "aten::conv2d" %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6: (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Tensor?", "int[]", "int[]", "int[]", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
return %0 : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @copy_inplace
|
||||
// Mutable/in-place op conversion, dropping result.
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<2x1x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=BATCH
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<2x2x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=SAME_CHANNELS
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=DIFFERENT_CHANNELS
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=TINY_SQUARE
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=HUGE_SQUARE
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=ZERO_KH_KW
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=ZERO_H_W
|
||||
|
||||
// BATCH: output #0: dense<0.000000e+00> : tensor<2x1x1x1xf32>
|
||||
|
||||
// SAME_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x2x1x1xf32>
|
||||
|
||||
// DIFFERENT_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
|
||||
|
||||
// TINY_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x2x2xf32>
|
||||
|
||||
// HUGE_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
|
||||
|
||||
// ZERO_KH_KW: output #0: dense<0.000000e+00> : tensor<1x1x3x3xf32>
|
||||
|
||||
// ZERO_H_W: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
|
||||
|
||||
func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
|
||||
return %3 : tensor<?x?x?x?xf32>
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x2x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=CHANNELS
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x3x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=HEIGHT
|
||||
|
||||
// RUN: npcomp-opt --convert-aten-to-tcf %s | not npcomp-run-mlir \
|
||||
// RUN: -invoke aten_conv_2d_nchw \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x2x3xf32>" \
|
||||
// RUN: -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=WIDTH
|
||||
|
||||
// CHANNELS: NPCOMP: aborting: input and filter in-channels must be equal
|
||||
// HEIGHT: NPCOMP: aborting: input height must be greater than or equal to filter KH-dimension
|
||||
// WIDTH: NPCOMP: aborting: input width must be greater than or equal to filter KW-dimension
|
||||
func @aten_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = "aten.conv2d"(%arg0, %arg1, %arg2, %0, %1, %2, %c1_i64) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64) -> tensor<?x?x?x?xf32>
|
||||
return %3 : tensor<?x?x?x?xf32>
|
||||
}
|
Loading…
Reference in New Issue