mirror of https://github.com/llvm/torch-mlir
Add lowering for _convolution.deprecated (#1259)
* Add lowering for _convolution.deprecatedpull/1260/head snapshot-20220822.573
parent
99fb4c8637
commit
c38308f3ef
|
@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$input,
|
||||||
|
AnyTorchTensorType:$weight,
|
||||||
|
AnyTorchOptionalTensorType:$bias,
|
||||||
|
AnyTorchListOfTorchIntType:$stride,
|
||||||
|
AnyTorchListOfTorchIntType:$padding,
|
||||||
|
AnyTorchListOfTorchIntType:$dilation,
|
||||||
|
Torch_BoolType:$transposed,
|
||||||
|
AnyTorchListOfTorchIntType:$output_padding,
|
||||||
|
Torch_IntType:$groups,
|
||||||
|
Torch_BoolType:$benchmark,
|
||||||
|
Torch_BoolType:$deterministic,
|
||||||
|
Torch_BoolType:$cudnn_enabled
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 12, 1);
|
||||||
|
}
|
||||||
|
void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 12, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -927,13 +927,14 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.convolution_overrideable to aten.convolution
|
// Decompose aten._convolution-like to aten.convolution
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAten_ConvolutionOp
|
template<typename ConvolutionLikeOp>
|
||||||
: public OpRewritePattern<Aten_ConvolutionOp> {
|
class DecomposeAten_ConvolutionLikeOp
|
||||||
|
: public OpRewritePattern<ConvolutionLikeOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
|
LogicalResult matchAndRewrite(ConvolutionLikeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
||||||
|
@ -2542,8 +2543,10 @@ public:
|
||||||
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
|
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
|
||||||
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
||||||
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
|
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
|
||||||
target.addIllegalOp<Aten_ConvolutionOp>();
|
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
|
||||||
patterns.add<DecomposeAten_ConvolutionOp>(context);
|
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
|
||||||
|
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||||
|
context);
|
||||||
target.addIllegalOp<AtenConv2dOp>();
|
target.addIllegalOp<AtenConv2dOp>();
|
||||||
patterns.add<DecomposeAtenConv2dOp>(context);
|
patterns.add<DecomposeAtenConv2dOp>(context);
|
||||||
patterns.add<DecomposeAtenArangeOp>(context);
|
patterns.add<DecomposeAtenArangeOp>(context);
|
||||||
|
|
|
@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
|
|
||||||
// Promote the two dtypes assuming non-zero rank.
|
// Promote the two dtypes assuming non-zero rank.
|
||||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||||
Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
|
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||||
|
|
|
@ -6341,6 +6341,10 @@ module {
|
||||||
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
|
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
|
||||||
|
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
return %arg0 : !torch.list<int>
|
return %arg0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
|
|
@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
|
||||||
|
|
||||||
def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
|
def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
|
||||||
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||||
|
|
||||||
|
def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
|
||||||
|
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||||
|
|
||||||
def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
|
def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -337,6 +337,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||||
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||||
|
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||||
|
|
|
@ -406,6 +406,118 @@ class _Convolution2DTF32Module(torch.nn.Module):
|
||||||
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
|
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||||
|
|
||||||
|
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, inputVec, weight):
|
||||||
|
return torch.ops.aten._convolution(inputVec,
|
||||||
|
weight,
|
||||||
|
bias=None,
|
||||||
|
stride=[3, 3],
|
||||||
|
padding=[2, 2],
|
||||||
|
dilation=[1, 1],
|
||||||
|
transposed=False,
|
||||||
|
output_padding=[0, 0],
|
||||||
|
groups=1,
|
||||||
|
benchmark=False,
|
||||||
|
deterministic=False,
|
||||||
|
cudnn_enabled=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
|
||||||
|
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||||
|
|
||||||
|
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, inputVec, weight):
|
||||||
|
return torch.ops.aten._convolution(inputVec,
|
||||||
|
weight,
|
||||||
|
bias=None,
|
||||||
|
stride=[3, 3],
|
||||||
|
padding=[2, 2],
|
||||||
|
dilation=[1, 1],
|
||||||
|
transposed=False,
|
||||||
|
output_padding=[0, 0],
|
||||||
|
groups=1,
|
||||||
|
benchmark=True,
|
||||||
|
deterministic=False,
|
||||||
|
cudnn_enabled=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
|
||||||
|
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||||
|
|
||||||
|
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, inputVec, weight):
|
||||||
|
return torch.ops.aten._convolution(inputVec,
|
||||||
|
weight,
|
||||||
|
bias=None,
|
||||||
|
stride=[3, 3],
|
||||||
|
padding=[2, 2],
|
||||||
|
dilation=[1, 1],
|
||||||
|
transposed=False,
|
||||||
|
output_padding=[0, 0],
|
||||||
|
groups=1,
|
||||||
|
benchmark=False,
|
||||||
|
deterministic=True,
|
||||||
|
cudnn_enabled=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
|
||||||
|
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||||
|
|
||||||
|
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, inputVec, weight):
|
||||||
|
return torch.ops.aten._convolution(inputVec,
|
||||||
|
weight,
|
||||||
|
bias=None,
|
||||||
|
stride=[3, 3],
|
||||||
|
padding=[2, 2],
|
||||||
|
dilation=[1, 1],
|
||||||
|
transposed=False,
|
||||||
|
output_padding=[0, 0],
|
||||||
|
groups=1,
|
||||||
|
benchmark=False,
|
||||||
|
deterministic=False,
|
||||||
|
cudnn_enabled=True)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||||
|
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||||
|
|
||||||
class ConvolutionModule2DGroups(torch.nn.Module):
|
class ConvolutionModule2DGroups(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue