diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 10ef39b9d..351c15dae 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6158,6 +6158,30 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
   }];
 }
 
+def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::Int.Scalar : (Scalar) -> (int)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a
+  );
+  let results = (outs
+    Torch_IntType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIntScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIntScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -7139,6 +7163,29 @@ def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
   let hasFolder = 1;
 }
 
+def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::ScalarImplicit : (Tensor) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchTensorType:$a
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScalarImplicitOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenScalarImplicitOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
index b55085b09..f4d6b1c26 100644
--- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -181,6 +181,20 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenScalarImplicitOp
+    : public OpConversionPattern<AtenScalarImplicitOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.a());
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::
     populateTensorScalarInteropPatternsAndLegality(TypeConverter &typeConverter,
                                                    RewritePatternSet &patterns,
@@ -201,4 +215,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
+  patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context);
+  target.addIllegalOp<AtenScalarImplicitOp>();
 }
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 294871b18..27622ac6c 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1019,6 +1019,23 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIntScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold float -> int conversion.
+  if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
+    return IntegerAttr::get(
+        mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
+        static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
+  }
+  // If the input is int type already, the op is an identity.
+  if (getType() == getOperand().getType())
+    return getOperand();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // NonValueTensorLiteralOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 5659c687f..aac10203e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -426,6 +426,9 @@ private:
   ChangeResult visitBinaryScalarOp(Operation *op,
                                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenScalarImplicitOp(
+      AtenScalarImplicitOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -982,6 +985,9 @@ ChangeResult TypeAnalyzer::visitOperation(
     return visitBinaryScalarOp(op, operands);
   }
 
+  if (auto scalarImplicit = dyn_cast<AtenScalarImplicitOp>(op))
+    return visitAtenScalarImplicitOp(scalarImplicit, operands);
+
   // Otherwise, this is an unknown operation. Just mark all results as
   // having reached a pessimistic fixpoint.
   return markAllPessimisticFixpoint(op->getResults());
@@ -1249,6 +1255,19 @@ ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
   return incorporateKnowledge(op.getResult(), knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenScalarImplicitOp(
+    AtenScalarImplicitOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge =
+      ValueKnowledge::getScalarPessimisticValueState(op.getContext());
+  Type dType = operands[0]->getValue().dtype;
+  if (dType.isa<mlir::FloatType>())
+    knowledge.setScalarType(Torch::FloatType::get(op->getContext()));
+  else if (dType.isa<mlir::IntegerType>())
+    knowledge.setScalarType(Torch::IntType::get(op->getContext()));
+  return incorporateKnowledge(op->getResult(0), knowledge);
+}
+
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f7039eb95..15b31de32 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -477,6 +477,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
     emit("aten::Float.str : (str) -> (float)")
     emit("aten::Int.float : (float) -> (int)")
+    emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
 
     # Primitive ops
     emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True)
@@ -522,6 +523,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::eq.device : (Device, Device) -> (bool)")
     emit("aten::ceil.float : (float) -> (int)", has_folder=True)
+    emit("aten::ScalarImplicit : (Tensor) -> (Scalar)")
 
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 098a2c8ba..ca84b0cf8 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -621,6 +621,7 @@ def EmbeddingModuleI64_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
 class EmbeddingModuleI32(torch.nn.Module):
 
     def __init__(self):
@@ -1816,8 +1817,10 @@ class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module):
 def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 2, 4))
 
+
 # ==============================================================================
 
+
 class FlipModule(torch.nn.Module):
 
     def __init__(self):
@@ -1857,3 +1860,43 @@ class DetachModule(torch.nn.Module):
                      module_factory=lambda: DetachModule())
 def DetachModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 2, 4))
+
+# ==============================================================================
+
+
+class ScalarImplicitFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+    ])
+    def forward(self, x):
+        return float(torch.ops.aten.ScalarImplicit(x))
+
+
+@register_test_case(module_factory=lambda: ScalarImplicitFloatModule())
+def ScalarImplicitFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand().double())
+
+
+class ScalarImplicitIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+    ])
+    def forward(self, x):
+        return int(torch.ops.aten.ScalarImplicit(x))
+
+
+@register_test_case(module_factory=lambda: ScalarImplicitIntModule())
+def ScalarImplicitIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100, ()))
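A minimal eager-mode sketch (illustrative only, not part of the patch) of the semantics the new ops model: aten.ScalarImplicit unwraps a zero-rank tensor into a Python scalar, and the aten.Int.Scalar fold truncates a float scalar toward zero, matching the static_cast<int64_t> in AtenIntScalarOp::fold.

    import torch

    x = torch.rand((), dtype=torch.float64)   # zero-rank tensor, as in ScalarImplicitFloatModule
    s = torch.ops.aten.ScalarImplicit(x)      # Python float equal to x.item()
    assert s == x.item()
    assert int(s) == int(x.item())            # truncation toward zero, like the Int.Scalar fold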