mirror of https://github.com/llvm/torch-mlir
Add lowering for _convolution.deprecated (#1259)
* Add lowering for _convolution.deprecatedpull/1260/head snapshot-20220822.573
parent
99fb4c8637
commit
c38308f3ef
|
@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$transposed,
|
||||
AnyTorchListOfTorchIntType:$output_padding,
|
||||
Torch_IntType:$groups,
|
||||
Torch_BoolType:$benchmark,
|
||||
Torch_BoolType:$deterministic,
|
||||
Torch_BoolType:$cudnn_enabled
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 12, 1);
|
||||
}
|
||||
void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 12, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -927,13 +927,14 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.convolution_overrideable to aten.convolution
|
||||
// Decompose aten._convolution-like to aten.convolution
|
||||
namespace {
|
||||
class DecomposeAten_ConvolutionOp
|
||||
: public OpRewritePattern<Aten_ConvolutionOp> {
|
||||
template<typename ConvolutionLikeOp>
|
||||
class DecomposeAten_ConvolutionLikeOp
|
||||
: public OpRewritePattern<ConvolutionLikeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
|
||||
using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ConvolutionLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
||||
|
@ -2542,8 +2543,10 @@ public:
|
|||
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
|
||||
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
||||
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
|
||||
target.addIllegalOp<Aten_ConvolutionOp>();
|
||||
patterns.add<DecomposeAten_ConvolutionOp>(context);
|
||||
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
|
||||
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
patterns.add<DecomposeAtenConv2dOp>(context);
|
||||
patterns.add<DecomposeAtenArangeOp>(context);
|
||||
|
|
|
@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
|
||||
// Promote the two dtypes assuming non-zero rank.
|
||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||
Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
|
||||
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
|
|
|
@ -6341,6 +6341,10 @@ module {
|
|||
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
|
||||
%0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
|
|
|
@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
|
|||
|
||||
def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
|
||||
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||
|
||||
|
||||
def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
|
||||
return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||
|
||||
def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
|
|
|
@ -337,6 +337,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||
|
|
|
@ -406,6 +406,118 @@ class _Convolution2DTF32Module(torch.nn.Module):
|
|||
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
|
||||
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=True,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
|
||||
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=True,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
|
||||
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class ConvolutionModule2DGroups(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue