diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index d5de9a6ed..7cc0fe233 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -1000,3 +1000,69 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
                    torch.randint(-10, 10, (3, 4)))
 
 
+class ElementwiseSubScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.sub(x, 2.1, alpha=2)
+
+@register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
+def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseSubScalarFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.sub(x, 2.1)
+
+@register_test_case(module_factory=lambda: ElementwiseSubScalarFloatModule())
+def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+class ElementwiseAddScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.add(x, 3.0)
+
+@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
+def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseAddScalarFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.add(x, 3.0, alpha=2)
+
+@register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule())
+def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index d2190b646..1c51dc3b3 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1676,6 +1676,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       return b.create<arith::SubIOp>(loc, lhs, scaled);
     }
   }
+  if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
+    Type dtype = converter->convertType(subScalar.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
+    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
+      return b.create<arith::SubFOp>(loc, self, mult);
+    } else if (dtype.isa<mlir::IntegerType>()) {
+      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
+      return b.create<arith::SubIOp>(loc, self, mult);
+    }
+    subScalar.emitError("unimplemented: dtype other than float and integer "
+                        "types are not supported.");
+    return nullptr;
+  }
+  if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
+    Type dtype = converter->convertType(addScalar.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
+    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
+      return b.create<arith::AddFOp>(loc, self, mult);
+    } else if (dtype.isa<mlir::IntegerType>()) {
+      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
+      return b.create<arith::AddIOp>(loc, self, mult);
+    }
+    addScalar.emitError("unimplemented: dtype other than float and integer "
+                        "types are not supported.");
+    return nullptr;
+  }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
     AtenMulTensorOp::Adaptor adaptor(operands);
     Type dtype = converter->convertType(mul.getType())
@@ -2244,7 +2280,8 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
             AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
-            AtenEqTensorOp, AtenLtTensorOp>(op))
+            AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
+            op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
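
Note (not part of the patch): the new payloads follow PyTorch's documented alpha semantics for the scalar variants, computing self - alpha * other for aten.sub.Scalar and self + alpha * other for aten.add.Scalar, which is what the MulF/MulI followed by SubF/SubI or AddF/AddI sequences above implement. A minimal eager-mode sketch of the expected numerics, assuming only a stock PyTorch install:

    import torch

    # Expected semantics of the lowered ops (illustrative only).
    x = torch.randint(10, (3, 4))
    assert torch.equal(torch.sub(x, 2, alpha=2), x - 2 * 2)

    y = torch.rand(3, 4)
    assert torch.allclose(torch.add(y, 3.0, alpha=2), y + 2 * 3.0)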