diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 196a67e5b..6824e05a1 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -178,7 +178,10 @@ TOSA_PASS_SET = {
     "TransposeIntNegDimsModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
-    "_LogSoftmaxModuleStable_basic",
+    "_LogSoftmaxModuleStable_basic",
+    "ElementwiseRemainderScalarModule_Int_Float_basic",
+    "ElementwiseRemainderScalarModule_Float_basic",
+    "ElementwiseRemainderScalarModule_Int_basic",
 }
 
 LTC_XFAIL_SET = {
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6d6459d5d..75c2d98fe 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7883,6 +7883,30 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRemainderScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index f11f8a52e..2453bb2d3 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -803,6 +803,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value other = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create(loc, self, other);
   }
+  if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
+    Type newResultType = converter->convertType(remScalar.getType())
+                             .cast<RankedTensorType>()
+                             .getElementType();
+
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
+    Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
+    Value result;
+
+    if (newResultType.isa<mlir::FloatType>()) {
+      result = b.create<arith::RemFOp>(loc, self, other);
+    } else {
+      result = b.create<arith::RemSIOp>(loc, self, other);
+    }
+
+    return result;
+  }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
     Type dtype = converter->convertType(reciprocal.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
@@ -943,14 +960,14 @@ public:
           AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
           AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
           AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
-          AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
-          AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
-          AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-          AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
-          AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-          AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-          AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
-          AtenTriuOp>(op))
+          AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
+          AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
+          AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+          AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
+          AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+          AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+          AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+          AtenLogicalOrOp, AtenTriuOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1681,14 +1698,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
       AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
       AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
-      AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
-      AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
-      AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
-      AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
-      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-      AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
-      AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
+      AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
+      AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
+      AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
+      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
+      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b32e8b602..67c9a8046 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -770,7 +770,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote LHS with scalar RHS.
   if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
           AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
-          AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
+          AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 610fdfa6c..f5e220050 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5441,6 +5441,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index b4865c6c2..83e67ea27 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -478,6 +478,9 @@ def aten〇mul〇Scalar(self: List[int], other: float) -> List[int]:
 def aten〇div〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇remainder〇Scalar(self: List[int], other: float) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 1e28b74b9..c86aaf804 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -549,6 +549,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
+    emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
     emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 536982c06..5cea1d4f7 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1304,6 +1304,70 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+# ==============================================================================
+
+
+class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
+def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3,)))
+
+
+# ==============================================================================
+
+
+class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
+def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
+    module.forward(torch.rand(10, 3))
+
+
+# ==============================================================================
+
+class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
+def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 2)))
+
 # ==============================================================================
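Note (illustration only, not part of the patch): in plain PyTorch, the scalar-remainder semantics that the three new e2e tests exercise look like the sketch below. The result stays integral for an int tensor with an int scalar and is promoted to floating point when the scalar divisor is a float, which is why the lowering converts both payload operands to the converted result element type before computing the remainder.

    import torch

    x_int = torch.tensor([5, 7, 9])          # int64 tensor (cf. ElementwiseRemainderScalarModule_Int)
    x_float = torch.tensor([5.0, 7.5, 9.0])  # float tensor (cf. ElementwiseRemainderScalarModule_Float)

    torch.remainder(x_int, 2)      # values 1, 1, 1; dtype stays int64
    torch.remainder(x_int, 2.0)    # values 1., 1., 1.; promoted to a float tensor
    torch.remainder(x_float, 2.0)  # values 1.0, 1.5, 1.0; elementwise float remainder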