diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index b52f0d47a..bf277481a 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -89,6 +89,8 @@ TORCHDYNAMO_XFAIL_SET = {
     "ReduceMaxAlongDimUnsignedInt_basic",
     #ERROR: value (-56) is not equal to golden value (200)
     "AtenIntTensorByteDtypeModule_basic",
+    # ERROR: assert isinstance(e, FakeTensor)
+    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
 }
 
 STABLEHLO_PASS_SET = {
@@ -155,6 +157,8 @@ STABLEHLO_PASS_SET = {
     "ElementwiseAddScalarFloatModule_basic",
     "ElementwiseAddScalarInt64Module_basic",
     "ElementwiseAddScalarIntModule_basic",
+    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
+    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
    "ElementwiseDivScalarModule_basic",
     "ElementwiseEqDiffWidthScalarModule_basic",
     "ElementwiseEqFloatScalarModule_basic",
@@ -537,6 +541,7 @@ TOSA_PASS_SET = {
     "ElementwiseDivScalarModule_basic",
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
+    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseMulScalarModule_float",
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 426bb750b..a214364e3 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
   return FloatAttr::get(Float64Type::get(context), value);
 }
 
-static Value getScalarValue(Value input, Location loc,
-                            PatternRewriter &rewriter) {
+static Value getScalarIntValue(Value input, Location loc,
+                               PatternRewriter &rewriter) {
   auto inputType = input.getType();
   if (inputType.isa<Torch::IntType>()) {
     return input;
   }
-  Value scalar = nullptr;
+
+  auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
+  if (!inputTensorType)
+    return nullptr;
+
+  Type inputDtype = inputTensorType.getOptionalDtype();
+  if (!inputDtype || !inputDtype.isInteger(64))
+    return nullptr;
+
+  std::optional<unsigned> inputRank = getTensorRank(input);
+  if (!inputRank || *inputRank != 0)
+    return nullptr;
+
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
-    std::optional<unsigned> tensorRank =
-        getTensorRank(valueTensorLiteralOp.getResult());
-    if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
-      auto tensorType =
-          valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
-      if (tensorType.getElementType().isa<mlir::IntegerType>()) {
-        auto val = valueTensorLiteralOp.getValue()
-                       .cast<DenseIntElementsAttr>()
-                       .getSplatValue<int64_t>();
-        scalar = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(val));
-      }
-    }
+    auto val = valueTensorLiteralOp.getValue()
+                   .cast<DenseIntElementsAttr>()
+                   .getSplatValue<int64_t>();
+    return rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(val));
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
-    scalar = primNumToTensorScalarOp.getA();
+    return primNumToTensorScalarOp.getA();
   }
-  return scalar;
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
@@ -869,8 +873,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
   if (op->getNumOperands() < 2) {
     return failure();
   }
-  auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
-  auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
+  auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
+  auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
   auto outType = op->getResult(0).getType();
 
   if (!lhs || !rhs) {
@@ -879,7 +883,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
   }
   if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
          op)) {
-    Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
+    Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
     if (!alpha) {
       return rewriter.notifyMatchFailure(op,
                                          "only int scalar alpha is supported");
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index f8e109439..f2044bcc9 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1919,6 +1919,52 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        x = torch.ops.prim.NumToTensor(5.0)
+        return torch.add(x, 3)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module())
+def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
+class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.x = torch.tensor(2, dtype=torch.int32)
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.add(self.x, 3)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module())
+def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class ElementwiseCloneModule(torch.nn.Module):
 
     def __init__(self):