From e55fc4deb501bc2f52d3b7ab001dc9dc19a817a9 Mon Sep 17 00:00:00 2001
From: powderluv
Date: Mon, 8 Aug 2022 22:59:57 -0700
Subject: [PATCH] Revert "E2E support for AtenRemainderScalarOp (#1119)"
 (#1190)

This reverts commit 34e207eeb502109957bebee6654709bb89f2d026.
---
 e2e_testing/torchscript/xfail_sets.py          |  5 +-
 .../Dialect/Torch/IR/GeneratedTorchOps.td      | 24 -------
 .../TorchToLinalg/Uncategorized.cpp            | 49 +++++---------
 lib/Dialect/Torch/Transforms/RefineTypes.cpp   |  2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp  |  4 --
 .../jit_ir/build_tools/shape_lib_gen.py        |  3 -
 .../jit_ir/build_tools/torch_ods_gen.py        |  1 -
 .../test_suite/elementwise.py                  | 64 -------------------
 8 files changed, 18 insertions(+), 134 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 6824e05a1..196a67e5b 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -178,10 +178,7 @@ TOSA_PASS_SET = {
     "TransposeIntNegDimsModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
-    "_LogSoftmaxModuleStable_basic",
-    "ElementwiseRemainderScalarModule_Int_Float_basic",
-    "ElementwiseRemainderScalarModule_Float_basic",
-    "ElementwiseRemainderScalarModule_Int_basic",
+    "_LogSoftmaxModuleStable_basic",
 }
 
 LTC_XFAIL_SET = {
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 75c2d98fe..6d6459d5d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7883,30 +7883,6 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
   let hasFolder = 1;
 }
 
-def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchScalarType:$other
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 2, 1);
-    }
-    void AtenRemainderScalarOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 2, 1);
-    }
-  }];
-}
-
 def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 2453bb2d3..f11f8a52e 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -803,23 +803,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value other = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create<arith::DivFOp>(loc, self, other);
   }
-  if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
-    Type newResultType = converter->convertType(remScalar.getType())
-                             .cast<RankedTensorType>()
-                             .getElementType();
-
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
-    Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
-    Value result;
-
-    if (newResultType.isa<mlir::FloatType>()) {
-      result = b.create<arith::RemFOp>(loc, self, other);
-    } else {
-      result = b.create<arith::RemSIOp>(loc, self, other);
-    }
-
-    return result;
-  }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
     Type dtype = converter->convertType(reciprocal.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
@@ -960,14 +943,14 @@ public:
              AtenMinimumOp, AtenMaximumOp,
             AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
             AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
-            AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
-            AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-            AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-            AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
-            AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
-            AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
-            AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-            AtenLogicalOrOp, AtenTriuOp>(op))
+            AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
+            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
+            AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
+            AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
+            AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
+            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+            AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
+            AtenTriuOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1698,14 +1681,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
       AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
       AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
-      AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
-      AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
-      AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
-      AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
-      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
-      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
-      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
-      AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
+      AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
+      AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
+      AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
+      AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+      AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
+      AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 43222b62b..97b7a4326 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -770,7 +770,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote LHS with scalar RHS.
   if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
           AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
-          AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
+          AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index f5e220050..610fdfa6c 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5441,10 +5441,6 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list
     return %0 : !torch.list
   }
-  func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {
-    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list
-    return %0 : !torch.list
-  }
   func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list
     return %0 : !torch.list
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 83e67ea27..b4865c6c2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -478,9 +478,6 @@ def aten〇mul〇Scalar(self: List[int], other: float) -> List[int]:
 def aten〇div〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
-def aten〇remainder〇Scalar(self: List[int], other: float) -> List[int]:
-    return upstream_shape_functions.unary(self)
-
 def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index c86aaf804..1e28b74b9 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -549,7 +549,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
-    emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
     emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 5cea1d4f7..536982c06 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1304,70 +1304,6 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
-# ==============================================================================
-
-
-class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1], torch.int64, True),
-    ])
-    def forward(self, x):
-        return torch.remainder(x, 2.0)
-
-
-@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
-def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3,)))
-
-
-# ==============================================================================
-
-
-class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float, True),
-    ])
-    def forward(self, x):
-        return torch.remainder(x, 2.0)
-
-
-@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
-def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
-    module.forward(torch.rand(10, 3))
-
-
-# ==============================================================================
-
-class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.int64, True),
-    ])
-    def forward(self, x):
-        return torch.remainder(x, 2)
-
-
-@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
-def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 2)))
-
 
 # ==============================================================================