[LINALG] Add E2E support for `aten.[Bool.Tensor|Float.Tensor]` op

- This commit adds lowering of the `aten.Bool.Tensor` and
  `aten.Float.Tensor` ops as part of the `convert-torch-to-linalg` pass.
- It also adds support for returning bool types.
- It also fixes the lowering of the `aten.Int.Tensor` op for non-zero-rank
  input tensors.
- If a scalar number is converted to a 0-d tensor and passed on to the
  `aten.Float.Tensor` op, the op folds to the scalar number.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/593/head
Gaurav Shukla 2022-02-09 17:25:14 +05:30
parent 9e7b6cab08
commit f00d1686c8
11 changed files with 343 additions and 45 deletions
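
For context, `aten.Int.Tensor`, `aten.Float.Tensor`, and `aten.Bool.Tensor` are the ops TorchScript produces when `int()`, `float()`, or `bool()` is called on a tensor, and they are only valid when the tensor holds exactly one element. A minimal eager-mode sketch of that behavior and of the fold described above (illustrative only, not code from this patch):

import torch

# `int()`, `float()`, and `bool()` on a tensor only succeed when the tensor
# holds exactly one element (0-d or all-unit dimensions).
print(int(torch.tensor(7)))                           # 7
print(float(torch.rand(1, 1, dtype=torch.float64)))   # some float in [0, 1)
print(bool(torch.tensor([[True]])))                   # True

# The fold mentioned above: a number wrapped into a 0-d tensor and converted
# straight back is just the original number.
assert float(torch.tensor(3.5)) == 3.5

# A multi-element tensor is rejected, which is what the runtime shape checks
# in the new lowering guard against for non-zero-rank inputs.
try:
    int(torch.tensor([1, 2]))
except (ValueError, RuntimeError):
    print("only one-element tensors can be converted to Python scalars")

The new e2e tests in this patch (TensorToInt, TensorToFloat, TensorToBool and their zero-rank variants) exercise exactly these conversions through the linalg lowering.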


@@ -122,7 +122,6 @@ def AddmmModuleFloat_basic(module, tu: TestUtils):
# ==============================================================================
class AddmmModuleBroadcastable(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -144,7 +143,6 @@ def AddmmModule_broadcastable(module, tu: TestUtils):
# ==============================================================================
class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -166,7 +164,6 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
# ==============================================================================
class AdaptiveAvgPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -263,6 +260,7 @@ class MaxPool2dModule(torch.nn.Module):
def forward(self, x):
return self.mp2d(x)
# ==============================================================================
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
@@ -328,6 +326,7 @@ class ConstantPadNdModule(torch.nn.Module):
def ConstantPadNdModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
# ==============================================================================
class ConstantPadNdStaticModule(torch.nn.Module):
def __init__(self):
@@ -346,6 +345,8 @@ class ConstantPadNdStaticModule(torch.nn.Module):
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
# ==============================================================================
class ConstantPadNdPartialStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -585,6 +586,8 @@ class SoftmaxIntModule(torch.nn.Module):
def SoftmaxIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class _SoftmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -718,22 +721,7 @@ class ContiguousModule(torch.nn.Module):
def ContiguousModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1))
class TensorToInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, x):
return int(x)
@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]))
# ==============================================================================
class LogSoftmaxIntModule(torch.nn.Module):
def __init__(self):
@@ -752,6 +740,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).double())
# ==============================================================================
class NumToTensorIntModule(torch.nn.Module):
def __init__(self):
@@ -769,6 +758,7 @@ class NumToTensorIntModule(torch.nn.Module):
def NumToTensorIntModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class NumToTensorFloatModule(torch.nn.Module):
def __init__(self):
@@ -808,6 +798,8 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
# ==============================================================================
class AddCMulModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -827,6 +819,8 @@ class AddCMulModule(torch.nn.Module):
def AddCMulModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
# ==============================================================================
class AddCDivModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -865,6 +859,8 @@ class tensorIntModule(torch.nn.Module):
def TensorIntModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class tensorFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -902,6 +898,7 @@ class DropoutModule(torch.nn.Module):
def DropoutModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class MeanModule(torch.nn.Module):
def __init__(self):
@@ -920,6 +917,7 @@ class MeanModule(torch.nn.Module):
def MeanModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 4))
# ==============================================================================
class MeanDynamicSizesModule(torch.nn.Module):
def __init__(self):
@@ -938,6 +936,7 @@ class MeanDynamicSizesModule(torch.nn.Module):
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 4))
# ==============================================================================
class NumelModule(torch.nn.Module):
def __init__(self):
@@ -955,6 +954,7 @@ class NumelModule(torch.nn.Module):
def NumelModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 5))
# ==============================================================================
class NumelZeroRankModule(torch.nn.Module):
def __init__(self):
@@ -972,6 +972,7 @@ class NumelZeroRankModule(torch.nn.Module):
def NumelZeroRankModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]))
# ==============================================================================
class BoolTensorReturnFalseModule(torch.nn.Module):
def __init__(self):
@@ -990,6 +991,7 @@ class BoolTensorReturnFalseModule(torch.nn.Module):
def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([0, 0], dtype=torch.bool))
# ==============================================================================
class BoolTensorReturnTrueModule(torch.nn.Module):
def __init__(self):
@@ -1008,6 +1010,7 @@ class BoolTensorReturnTrueModule(torch.nn.Module):
def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
# ==============================================================================
class BoolTensorReturnMixedModule(torch.nn.Module):
def __init__(self):


@@ -0,0 +1,125 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class TensorToIntZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, x):
return int(x)
@register_test_case(module_factory=lambda: TensorToIntZeroRank())
def TensorToIntZeroRank_basic(module, tu: TestUtils):
module.forward(torch.randint(10, ()))
# ==============================================================================
class TensorToInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return int(x)
@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (1, 1)))
# ==============================================================================
class TensorToFloatZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float64, True),
])
def forward(self, x):
return float(x)
@register_test_case(module_factory=lambda: TensorToFloatZeroRank())
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
module.forward(torch.rand((), dtype=torch.float64))
# ==============================================================================
class TensorToFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
return float(x)
@register_test_case(module_factory=lambda: TensorToFloat())
def TensorToFloat_basic(module, tu: TestUtils):
module.forward(torch.rand((1, 1), dtype=torch.float64))
# ==============================================================================
class TensorToBoolZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.bool, True),
])
def forward(self, x):
return bool(x)
@register_test_case(module_factory=lambda: TensorToBoolZeroRank())
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
module.forward(torch.tensor(1, dtype=torch.bool))
# ==============================================================================
class TensorToBool(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, x):
return bool(x)
@register_test_case(module_factory=lambda: TensorToBool())
def TensorToBool_basic(module, tu: TestUtils):
module.forward(torch.tensor([[1]], dtype=torch.bool))


@@ -51,6 +51,7 @@ from . import threshold
from . import histogram_binning_calibration
from . import table_batch_embedding
from . import rng
from . import cast
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']


@@ -2935,6 +2935,21 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
let hasFolder = 1;
}
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`";
let arguments = (ins
AnyTorchTensorType:$a
);
let results = (outs
Torch_FloatType:$result
);
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
let hasFolder = 1;
}
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
AllowsTypeRefinement,
HasValueSemantics


@@ -3804,30 +3804,41 @@ public:
};
} // namespace
// Casts a 0d integer tensor to elemental type.
namespace {
class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
// Casts a tensor of exactly one element to an elemental type.
template <typename OpTy>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIntTensorOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op,
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value intTensor = adaptor.a();
auto tensorType = intTensor.getType().cast<RankedTensorType>();
Location loc = op.getLoc();
Value input = adaptor.a();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputSizes.size();
if (tensorType.getRank() != 0)
return rewriter.notifyMatchFailure(
op, "invalid rank: the rank of the input tensor must be 0");
// The `input` tensor must contain exactly one element, i.e., either the
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
// are unit.
Value constantOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
for (int64_t i = 0; i < inputRank; i++)
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
// Extract the only element from the `input` tensor.
Value constantZero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
SmallVector<Value> indices(inputRank, constantZero);
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
return success();
}
};
} // namespace
namespace {
class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
public:
@@ -3853,7 +3864,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
@@ -4618,8 +4628,13 @@ public:
context);
target.addIllegalOp<AtenContiguousOp>();
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp>();
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(
typeConverter, context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
typeConverter, context);
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(
typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
target.addIllegalOp<AtenDropoutOp>();
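
The `ConvertAtenTensorToScalarLikeOp` pattern registered above generalizes the old 0-d-only `ConvertAtenIntTensorOp`: rather than rejecting non-zero-rank inputs, it emits runtime checks that every input dimension equals 1 and then extracts the single element at the all-zeros index. A rough Python equivalent of that logic, using a hypothetical helper name for illustration only:

import torch

def tensor_to_scalar(t: torch.Tensor):
    # Mirrors the checkDimEqualHelper asserts: every dimension must be 1, so
    # the tensor holds exactly one element (a 0-d tensor passes trivially).
    for d in t.shape:
        assert d == 1, "input tensor must contain exactly one element"
    # Mirrors the tensor.extract at index [0, 0, ..., 0].
    return t[(0,) * t.dim()].item()

print(tensor_to_scalar(torch.tensor(42)))         # 0-d input -> 42
print(tensor_to_scalar(torch.full((1, 1), 2.5)))  # rank-2, one-element input -> 2.5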


@@ -1229,13 +1229,31 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
// If an scalar number is converted to a 0-d tensor and passed on to
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Int.Tensor, fold to the scalar number.
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
return numToTensorScalar.a();
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Float.Tensor, fold to the scalar number.
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
return numToTensorScalar.a();
return nullptr;
}
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"


@@ -169,7 +169,7 @@ static LogicalResult mungeFunction(
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
op.emitError("Supported return types:"
"mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64,"
"mri1, mri32, mri64, mrf32, mrf64, i1, i64, f32, f64,"
"(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
"(mrf32, mrf32, mrf32)");
isSupported = false;
@@ -195,6 +195,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
Type i1 = b.getI1Type();
Type i64 = b.getI64Type();
Type f32 = b.getF32Type();
Type f64 = b.getF64Type();
@@ -204,6 +205,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
mri64,
mrf32,
mrf64,
i1,
i64,
f32,
f64,


@@ -622,6 +622,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::t : (Tensor) -> (Tensor)")


@@ -53,6 +53,10 @@ class RefBackendInvoker:
def consume_return_mrf64(a):
self.result = unranked_memref_to_numpy(a, np.float64)
@ctypes.CFUNCTYPE(None, ctypes.c_bool)
def consume_return_i1(a):
self.result = a
@ctypes.CFUNCTYPE(None, ctypes.c_int)
def consume_return_i64(a):
self.result = a
@@ -113,6 +117,9 @@ class RefBackendInvoker:
self.ee.register_runtime("refbackend_consume_func_return_mrf64",
consume_return_mrf64)
self.ee.register_runtime("refbackend_consume_func_return_i1",
consume_return_i1)
self.ee.register_runtime("refbackend_consume_func_return_i64",
consume_return_i64)


@@ -57,16 +57,117 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
// -----
// CHECK-LABEL: func @integer_extract
// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[I]], %[[C0]] : tensor<?x?xi64>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[I]], %[[C1]] : tensor<?x?xi64>
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
return %0 : !torch.float
}
// -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[F]], %[[C0]] : tensor<?x?xf64>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[F]], %[[C1]] : tensor<?x?xf64>
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
return %0 : !torch.float
}
// -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RES]] : !torch.bool
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
return %0 : !torch.bool
}
// -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xi1>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xi1>
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
return %0 : !torch.bool
}
// -----


@@ -721,6 +721,16 @@ func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
return %scalar : !torch.int
}
// CHECK-LABEL: func @torch.aten.Float.Tensor(
// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>
// CHECK: return %[[NUM]] : !torch.float
func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
%tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.float -> !torch.vtensor<[],f64>
%scalar = torch.aten.Float.Tensor %tensor : !torch.vtensor<[],f64> -> !torch.float
return %scalar : !torch.float
}
// CHECK-LABEL: func @torch.aten.squeeze$zero_rank(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>