diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index d7d140c69..87032b123 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -122,7 +122,6 @@ def AddmmModuleFloat_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AddmmModuleBroadcastable(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -144,7 +143,6 @@ def AddmmModule_broadcastable(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -166,7 +164,6 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class AdaptiveAvgPool2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -263,6 +260,7 @@ class MaxPool2dModule(torch.nn.Module):
     def forward(self, x):
         return self.mp2d(x)
 
+# ==============================================================================
 
 @register_test_case(module_factory=lambda: MaxPool2dModule())
 def MaxPool2dModule_basic(module, tu: TestUtils):
@@ -328,6 +326,7 @@ class ConstantPadNdModule(torch.nn.Module):
 def ConstantPadNdModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
 
+# ==============================================================================
 
 class ConstantPadNdStaticModule(torch.nn.Module):
     def __init__(self):
@@ -346,6 +345,8 @@ class ConstantPadNdStaticModule(torch.nn.Module):
 def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
 
+# ==============================================================================
+
 class ConstantPadNdPartialStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -585,6 +586,8 @@ class SoftmaxIntModule(torch.nn.Module):
 def SoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4))
 
+# ==============================================================================
+
 class _SoftmaxModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -718,22 +721,7 @@ class ContiguousModule(torch.nn.Module):
 def ContiguousModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1))
 
-class TensorToInt(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([], torch.int64, True),
-    ])
-    def forward(self, x):
-        return int(x)
-
-
-@register_test_case(module_factory=lambda: TensorToInt())
-def TensorToInt_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10,[]))
+# ==============================================================================
 
 class LogSoftmaxIntModule(torch.nn.Module):
     def __init__(self):
@@ -752,6 +740,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
 def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
+# ==============================================================================
 
 class NumToTensorIntModule(torch.nn.Module):
     def __init__(self):
@@ -769,6 +758,7 @@ class NumToTensorIntModule(torch.nn.Module):
 def NumToTensorIntModule_basic(module, tu: TestUtils):
     module.forward()
 
+# ==============================================================================
 
 class NumToTensorFloatModule(torch.nn.Module):
     def __init__(self):
@@ -808,6 +798,8 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
 def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
 
+# ==============================================================================
+
 class AddCMulModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -827,6 +819,8 @@ class AddCMulModule(torch.nn.Module):
 def AddCMulModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
 
+# ==============================================================================
+
 class AddCDivModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -865,6 +859,8 @@ class tensorIntModule(torch.nn.Module):
 def TensorIntModule_basic(module, tu: TestUtils):
     module.forward()
 
+# ==============================================================================
+
 class tensorFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -902,6 +898,7 @@ class DropoutModule(torch.nn.Module):
 def DropoutModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+# ==============================================================================
 
 class MeanModule(torch.nn.Module):
     def __init__(self):
@@ -920,6 +917,7 @@ class MeanModule(torch.nn.Module):
 def MeanModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 4))
 
+# ==============================================================================
 
 class MeanDynamicSizesModule(torch.nn.Module):
     def __init__(self):
@@ -938,6 +936,7 @@ class MeanDynamicSizesModule(torch.nn.Module):
 def MeanDynamicSizesModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 4))
 
+# ==============================================================================
 
 class NumelModule(torch.nn.Module):
     def __init__(self):
@@ -955,6 +954,7 @@ class NumelModule(torch.nn.Module):
 def NumelModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 5))
 
+# ==============================================================================
 
 class NumelZeroRankModule(torch.nn.Module):
     def __init__(self):
@@ -972,6 +972,7 @@ class NumelZeroRankModule(torch.nn.Module):
 def NumelZeroRankModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10,[]))
 
+# ==============================================================================
 
 class BoolTensorReturnFalseModule(torch.nn.Module):
     def __init__(self):
@@ -990,6 +991,7 @@ class BoolTensorReturnFalseModule(torch.nn.Module):
 def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([0, 0], dtype=torch.bool))
 
+# ==============================================================================
 
 class BoolTensorReturnTrueModule(torch.nn.Module):
     def __init__(self):
@@ -1008,6 +1010,7 @@ class BoolTensorReturnTrueModule(torch.nn.Module):
 def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
 
+# ==============================================================================
 
 class BoolTensorReturnMixedModule(torch.nn.Module):
     def __init__(self):
diff --git a/e2e_testing/torchscript/cast.py b/e2e_testing/torchscript/cast.py
new file mode 100644
index 000000000..834066949
--- /dev/null
+++ b/e2e_testing/torchscript/cast.py
@@ -0,0 +1,125 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class TensorToIntZeroRank(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+    ])
+    def forward(self, x):
+        return int(x)
+
+
+@register_test_case(module_factory=lambda: TensorToIntZeroRank())
+def TensorToIntZeroRank_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, ()))
+
+# ==============================================================================
+
+class TensorToInt(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return int(x)
+
+
+@register_test_case(module_factory=lambda: TensorToInt())
+def TensorToInt_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (1, 1)))
+
+# ==============================================================================
+
+class TensorToFloatZeroRank(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+    ])
+    def forward(self, x):
+        return float(x)
+
+
+@register_test_case(module_factory=lambda: TensorToFloatZeroRank())
+def TensorToFloatZeroRank_basic(module, tu: TestUtils):
+    module.forward(torch.rand((), dtype=torch.float64))
+
+# ==============================================================================
+
+class TensorToFloat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        return float(x)
+
+
+@register_test_case(module_factory=lambda: TensorToFloat())
+def TensorToFloat_basic(module, tu: TestUtils):
+    module.forward(torch.rand((1, 1), dtype=torch.float64))
+
+# ==============================================================================
+
+class TensorToBoolZeroRank(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.bool, True),
+    ])
+    def forward(self, x):
+        return bool(x)
+
+
+@register_test_case(module_factory=lambda: TensorToBoolZeroRank())
+def TensorToBoolZeroRank_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(1, dtype=torch.bool))
+
+# ==============================================================================
+
+class TensorToBool(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, x):
+        return bool(x)
+
+
+@register_test_case(module_factory=lambda: TensorToBool())
+def TensorToBool_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[1]], dtype=torch.bool))
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index ce54fa48b..5b8e515af 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -51,6 +51,7 @@ from . import threshold
 from . import histogram_binning_calibration
 from . import table_batch_embedding
 from . import rng
+from . import cast
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index defad0272..99e81c847 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -2935,6 +2935,21 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
   let hasFolder = 1;
 }
 
+def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`";
+  let arguments = (ins
+    AnyTorchTensorType:$a
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
+  let hasFolder = 1;
+}
+
 def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 20ebc2db8..8a26fd601 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3804,30 +3804,41 @@ public:
 };
 } // namespace
 
-// Casts a 0d integer tensor to elemental type.
 namespace {
-class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
+// Casts a tensor of exactly one element to an elemental type.
+template <typename OpTy>
+class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenIntTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op,
+                  typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Value intTensor = adaptor.a();
-    auto tensorType = intTensor.getType().cast<RankedTensorType>();
+    Location loc = op.getLoc();
+    Value input = adaptor.a();
+    SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
+    int64_t inputRank = inputSizes.size();
 
-    if (tensorType.getRank() != 0)
-      return rewriter.notifyMatchFailure(
-          op, "invalid rank: the rank of the input tensor must be 0");
+    // The `input` tensor must contain exactly one element, i.e., either the
+    // `input` is a zero rank tensor or all the dimensions of the `input` tensor
+    // are unit.
+    Value constantOne =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
+    for (int64_t i = 0; i < inputRank; i++)
+      checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
 
-    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
+    // Extract the only element from the `input` tensor.
+    Value constantZero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    SmallVector<Value> indices(inputRank, constantZero);
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
     return success();
   }
 };
 } // namespace
 
-
 namespace {
 class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
 public:
@@ -3853,7 +3864,6 @@ public:
 };
 } // namespace
 
-
 namespace {
 class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
 public:
@@ -4618,8 +4628,13 @@ public:
                                                                     context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-    target.addIllegalOp<AtenIntTensorOp>();
-    patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
+    target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
+    patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(
+        typeConverter, context);
     target.addIllegalOp<AtenFill_ScalarOp>();
     patterns.add<ConvertAtenFill_ScalarOp>(typeConverter, context);
     target.addIllegalOp<AtenBroadcastToOp>();
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 1ad23b8d6..41b8c5a58 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1229,13 +1229,31 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIntTensorOp
+//===----------------------------------------------------------------------===//
+
 OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
-  // If an scalar number is converted to a 0-d tensor and passed on to
+  // If a scalar number is converted to a 0-d tensor and passed on to
   // aten.Int.Tensor, fold to the scalar number.
   if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
     return numToTensorScalar.a();
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenFloatTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
+  // If a scalar number is converted to a 0-d tensor and passed on to
+  // aten.Float.Tensor, fold to the scalar number.
+  if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
+    return numToTensorScalar.a();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+
 #define GET_OP_CLASSES
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 964f24869..2ef5ad755 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -169,7 +169,7 @@ static LogicalResult mungeFunction(
     std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
     if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
       op.emitError("Supported return types:"
-                   "mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64,"
+                   "mri1, mri32, mri64, mrf32, mrf64, i1, i64, f32, f64,"
                    "(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
                    "(mrf32, mrf32, mrf32)");
       isSupported = false;
@@ -195,6 +195,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
   Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
   Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
+  Type i1 = b.getI1Type();
   Type i64 = b.getI64Type();
   Type f32 = b.getF32Type();
   Type f64 = b.getF64Type();
@@ -204,6 +205,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
       mri64,
       mrf32,
       mrf64,
+      i1,
       i64,
       f32,
       f64,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index b26cbc6ee..814bbf8f4 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -622,6 +622,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::IntImplicit : (Tensor) -> (int)")
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
     emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
+    emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
     emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::t : (Tensor) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 9006b9d73..43da873cc 100644
--- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -53,6 +53,10 @@ class RefBackendInvoker:
         def consume_return_mrf64(a):
             self.result = unranked_memref_to_numpy(a, np.float64)
 
+        @ctypes.CFUNCTYPE(None, ctypes.c_bool)
+        def consume_return_i1(a):
+            self.result = a
+
         @ctypes.CFUNCTYPE(None, ctypes.c_int)
         def consume_return_i64(a):
             self.result = a
@@ -113,6 +117,9 @@ class RefBackendInvoker:
         self.ee.register_runtime("refbackend_consume_func_return_mrf64",
                                  consume_return_mrf64)
 
+        self.ee.register_runtime("refbackend_consume_func_return_i1",
+                                 consume_return_i1)
+
         self.ee.register_runtime("refbackend_consume_func_return_i64",
                                  consume_return_i64)
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 96e03a079..040f7d7c4 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -57,16 +57,117 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
 
 // -----
 
-// CHECK-LABEL: func @integer_extract
-// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
-// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
-// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
-// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
-// CHECK: return %[[RET]] : !torch.int
-func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
-  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
-  return %0 : !torch.int
-}
+// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
+// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
+// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
+// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
+// CHECK: return %[[RET]] : !torch.int
+func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
+// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %[[I]], %[[C0]] : tensor<?x?xi64>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[I]], %[[C1]] : tensor<?x?xi64>
+// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
+// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
+// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
+// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
+// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
+// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
+// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
+// CHECK: return %[[RET]] : !torch.int
+func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
+  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
+// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
+// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
+// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
+// CHECK: return %[[RET]] : !torch.float
+func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
+  %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
+  return %0 : !torch.float
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
+// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %[[F]], %[[C0]] : tensor<?x?xf64>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[F]], %[[C1]] : tensor<?x?xf64>
+// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
+// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
+// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
+// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
+// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
+// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
+// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
+// CHECK: return %[[RET]] : !torch.float
+func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
+  %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
+  return %0 : !torch.float
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
+// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
+// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
+// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
+// CHECK: return %[[RES]] : !torch.bool
+func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
+  %0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
+  return %0 : !torch.bool
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank
+// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
+// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xi1>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xi1>
+// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
+// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
+// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
+// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
+// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
+// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
+// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
+// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
+// CHECK: return %[[RET]] : !torch.bool
+func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
+  %0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
+  return %0 : !torch.bool
+}
 
 // -----
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index c8d86f1ab..e36a5f924 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -721,6 +721,16 @@ func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
   return %scalar : !torch.int
 }
 
+// CHECK-LABEL: func @torch.aten.Float.Tensor(
+// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
+// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>
+// CHECK: return %[[NUM]] : !torch.float
+func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
+  %tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.float -> !torch.vtensor<[],f64>
+  %scalar = torch.aten.Float.Tensor %tensor : !torch.vtensor<[],f64> -> !torch.float
+  return %scalar : !torch.float
+}
+
 
 // CHECK-LABEL: func @torch.aten.squeeze$zero_rank(
 // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
 // CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>