Add lowering for _convolution.deprecated (#1259)

* Add lowering for _convolution.deprecated
Alex Tsao 2022-08-22 11:17:36 +08:00 committed by GitHub
parent 99fb4c8637
commit c38308f3ef
7 changed files with 166 additions and 9 deletions


@@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
  }];
}
def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
Torch_BoolType:$benchmark,
Torch_BoolType:$deterministic,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 12, 1);
}
void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 12, 1);
}
}];
}
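For reference, the JIT schema this op definition is generated from can be inspected straight from PyTorch; a quick sanity check, not part of this change:

import torch
# Standard PyTorch API: each overload of an op packet carries its
# registered JIT schema.
print(torch.ops.aten._convolution.deprecated._schema)
# Prints roughly: aten::_convolution.deprecated(Tensor input, Tensor weight,
# Tensor? bias, int[] stride, ..., bool cudnn_enabled) -> Tensor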
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
  AllowsTypeRefinement,
  HasValueSemantics,


@@ -927,13 +927,14 @@ public:
};
} // namespace
// Decompose aten._convolution-like to aten.convolution
namespace {
template <typename ConvolutionLikeOp>
class DecomposeAten_ConvolutionLikeOp
    : public OpRewritePattern<ConvolutionLikeOp> {
public:
  using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvolutionLikeOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
@@ -2542,8 +2543,10 @@ public:
    patterns.add<DecomposeAtenNativeBatchNormOp>(context);
    target.addIllegalOp<AtenConvolutionOverrideableOp>();
    patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
    target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
    patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
                 DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
        context);
    target.addIllegalOp<AtenConv2dOp>();
    patterns.add<DecomposeAtenConv2dOp>(context);
    patterns.add<DecomposeAtenArangeOp>(context);
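The pattern forwards the first nine operands to aten.convolution and drops the trailing backend flags. A minimal Python sketch of the same semantics (the helper name here is hypothetical, not torch-mlir API):

import torch

def _convolution_deprecated_as_convolution(input, weight, bias, stride,
                                           padding, dilation, transposed,
                                           output_padding, groups,
                                           benchmark, deterministic,
                                           cudnn_enabled):
    # benchmark/deterministic/cudnn_enabled only steer backend selection;
    # the computed result is that of aten.convolution on the same operands.
    return torch.ops.aten.convolution(input, weight, bias, stride, padding,
                                      dilation, transposed, output_padding,
                                      groups)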


@@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
    // Promote the two dtypes assuming non-zero rank.
    if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
            Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
            AtenConvolutionOverrideableOp>(op)) {
      auto knowledge =
          ValueKnowledge::getTensorPessimisticValueState(op->getContext());
      knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
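The promotion itself follows PyTorch's standard type-promotion rule; as a point of reference in plain PyTorch:

import torch
# The result dtype of these convolution-like ops is the promotion of the
# input and weight dtypes:
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64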


@@ -6341,6 +6341,10 @@ module {
    %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> { func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int> return %arg0 : !torch.list<int>
} }


@@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
    return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)

def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
    return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)

def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
    return self
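A quick check of the new shape function (assuming the upstream aten〇convolution definition from the shape library is in scope); the shapes match the e2e tests added below:

print(aten〇_convolution〇deprecated(
    [3, 3, 10, 10], [3, 3, 2, 2], None,
    [3, 3], [2, 2], [1, 1], False, [0, 0], 1,
    False, False, False))
# Expected [3, 3, 5, 5]: per spatial dim,
# (10 + 2*2 - 1*(2 - 1) - 1) // 3 + 1 = 5.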


@@ -337,6 +337,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
    emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
    emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit( emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"


@@ -406,6 +406,118 @@ class _Convolution2DTF32Module(torch.nn.Module):
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
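        # Omitting the trailing allow_tf32 argument makes this call (and the
        # ones in the tests below) resolve to the aten::_convolution.deprecated
        # overload exercised by this change.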
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=True,
deterministic=False,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=True,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
weight,
bias=None,
stride=[3, 3],
padding=[2, 2],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
benchmark=False,
deterministic=False,
cudnn_enabled=True)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class ConvolutionModule2DGroups(torch.nn.Module):
    def __init__(self):
        super().__init__()