Add 2D case for convolution (#693)

pull/745/head snapshot-20220408.376
gpetters94 2022-04-08 00:47:57 -04:00 committed by GitHub
parent fa0b24a73c
commit 9ec0683e92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 301 additions and 39 deletions

View File

@ -18,6 +18,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"TableBatchEmbeddingModule_basic",
"MobilenetV2Module_basic",
"MobilenetV3Module_basic",
"ConvolutionModule3D_basic",
"ConvolutionModule1D_basic",
}
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
@ -156,4 +158,5 @@ TOSA_PASS_SET = {
"GeluBackwardModule_basic",
"ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic",
"ConvolutionModule2DStatic_basic",
}

View File

@ -2728,6 +2728,68 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
}];
}
def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
ListOfTorchIntType:$stride,
ListOfTorchIntType:$padding,
ListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
ListOfTorchIntType:$output_padding,
Torch_IntType:$groups
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenConvolutionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}
def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
ListOfTorchIntType:$stride,
ListOfTorchIntType:$padding,
ListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
ListOfTorchIntType:$output_padding,
Torch_IntType:$groups
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}
def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -449,58 +449,65 @@ public:
} // namespace
namespace {
class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenConv2dOp op, OpAdaptor adaptor,
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.input(); /* in form of N*C*H*W */
Value weight = adaptor.weight(); /* in form of F*C*H*W */
Value groups = adaptor.groups();
Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
if (inRank != 4)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D convolution currently supported");
Type intType = IntegerType::get(context, 64);
auto castIndexToInt = [&](Value v) {
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
};
Value N = getDimOp(rewriter, loc, input, 0);
Value Hin = getDimOp(rewriter, loc, input, 2);
Value Win = getDimOp(rewriter, loc, input, 3);
Value F = getDimOp(rewriter, loc, weight, 0);
Value weightH = getDimOp(rewriter, loc, weight, 2);
Value weightW = getDimOp(rewriter, loc, weight, 3);
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
SmallVector<int64_t> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure(
op, "only support constant padding values");
}
SmallVector<int64_t, 2> strideInts;
SmallVector<int64_t> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t, 2> dilationInts;
SmallVector<int64_t> dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
Value c1 =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, groups, c1);
rewriter.create<cf::AssertOp>(
loc, groupEqual1, rewriter.getStringAttr("expect groups to be 1"));
Value N = getDimOp(rewriter, loc, input, 0);
SmallVector<Value> inDims;
for (size_t i = 2; i < inRank; i++)
inDims.push_back(getDimOp(rewriter, loc, input, i));
Value F = getDimOp(rewriter, loc, weight, 0);
SmallVector<Value> weightDims;
for (size_t i = 2; i < inRank; i++)
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
// Guard unused values (transposed, groups)
int64_t group_size;
if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) ||
group_size != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: only group size of 1 supported");
bool transposed = true;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) ||
transposed)
return rewriter.notifyMatchFailure(
op, "unimplemented: only non-transposed convolution supported");
// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
@ -516,15 +523,14 @@ public:
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
Value Hout = torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, Hin, paddingIntValues[0], dilationIntValues[0],
castIndexToInt(weightH), strideIntValues[0]);
Value Wout = torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
castIndexToInt(weightW), strideIntValues[1]);
SmallVector<Value> outDims{N, F};
for (size_t i = 0; i < inRank - 2; i++)
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
castIndexToInt(weightDims[i]), strideIntValues[i]));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, F, Hout, Wout}, elementType);
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, outDims, elementType);
Value bias = adaptor.bias();
Value biasInitTensor;
@ -559,14 +565,17 @@ public:
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value conv2d =
// TODO: add 1D and 3D case
Value conv =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
biasInitTensor, stridesAttr, dilationAttr)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
};
@ -584,6 +593,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<ConvertAtenConv2dOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
}

View File

@ -1677,8 +1677,8 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
}
template <>
LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
AtenConv2dOp op, OpAdaptor adaptor,
LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto input = adaptor.input();
@ -1692,13 +1692,19 @@ LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
if (!inputTy || !weightTy || !outputTy)
return op.emitError(
"Input, weight and output to Conv2d must be ranked tensors");
"Input, weight and output to Convolution must be ranked tensors");
auto inputElemTy = inputTy.getElementType();
auto weightElemTy = weightTy.getElementType();
auto inputShape = inputTy.getShape();
auto weightShape = weightTy.getShape();
if (inputTy.getRank() != 4)
return op.emitError("Unimplemented: only 2D convolutions supported");
if (!weightTy.hasStaticShape())
return op.emitError("Unimplemented: TOSA only supports static weight");
// Bias is optional. TOSA mandates a zero tensor here, so construct one if
// required.
auto bias = adaptor.bias();
@ -3140,7 +3146,7 @@ public:
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
INSERT_ATENOP_PATTERN(AtenConv2dOp);
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReshapeOp);
INSERT_ATENOP_PATTERN(AtenBatchNormOp);

View File

@ -737,6 +737,47 @@ public:
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution
namespace {
class DecomposeAtenConvolutionOverrideableOp
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), op.transposed(),
op.output_padding(), op.groups());
return success();
}
};
} // namespace
// Decompose aten.conv2d to aten.convolution
namespace {
class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenConv2dOp op,
PatternRewriter &rewriter) const override {
Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), cstFalse, emptyList,
op.groups());
return success();
}
};
} // namespace
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@ -1674,6 +1715,10 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);

View File

@ -549,7 +549,8 @@ ChangeResult TypeAnalyzer::visitOperation(
}
// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp>(op)) {
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
AtenConvolutionOverrideableOp>(op)) {
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(

View File

@ -2295,6 +2295,10 @@ module {
} : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
return %1 : !torch.bool
}
func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.conv_output_size(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}

View File

@ -762,6 +762,9 @@ def atentopk(self: List[int], k: int, dim: int = -1, largest: bool = True, so
def atenconv2d(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
return upstream_shape_helpers.conv2d(input, weight, bias, stride, padding, dilation, groups)
def atenconvolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
return upstream_shape_helpers.conv_output_size(input, weight, bias, stride, padding, dilation, groups)
def atenbatch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
# Torch's symbolic shape analysis is a bit looser about optional
# arguments than we are, so their batch_norm helper function works

View File

@ -308,6 +308,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
)
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
)

View File

@ -133,3 +133,130 @@ class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
# ==============================================================================
class ConvolutionModule1D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1],
padding=[0],
dilation=[1],
transposed=False,
output_padding=[0],
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule1D())
def ConvolutionModule1D_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10), torch.randn(3, 3, 2))
class ConvolutionModule2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule2D())
def ConvolutionModule2D_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class ConvolutionModule3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1, 1, 1],
padding=[0, 0, 0],
dilation=[1, 1, 1],
transposed=False,
output_padding=[0, 0, 0],
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule3D())
def ConvolutionModule3D_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10, 10), torch.randn(3, 3, 2, 2, 2))
class ConvolutionModule2DStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 3, 10, 10], torch.float32, True),
([3, 3, 2, 2], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule2DStatic())
def ConvolutionModule2DStatic_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class ConvolutionModule2DStrided(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule2DStrided())
def ConvolutionModule2DStrided_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))