mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.[Bool.Tensor|Float.Tensor]` op
- This commit adds lowering of `aten.Bool.Tensor` and `aten.Float.Tensor` op as a part of `convert-torch-to-linalg` pass. - It also adds support for returning bool types. - It also fixes lowering of the `aten.Int.Tensor` op for non-zero rank input tensors. - If a scalar number is converted to a 0-d tensor and passed on to the `aten.Float.Tensor` op, it folds to the scalar number. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/593/head
parent
9e7b6cab08
commit
f00d1686c8
|
@ -122,7 +122,6 @@ def AddmmModuleFloat_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AddmmModuleBroadcastable(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -144,7 +143,6 @@ def AddmmModule_broadcastable(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -166,7 +164,6 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AdaptiveAvgPool2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -263,6 +260,7 @@ class MaxPool2dModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||
|
@ -328,6 +326,7 @@ class ConstantPadNdModule(torch.nn.Module):
|
|||
def ConstantPadNdModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ConstantPadNdStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -346,6 +345,8 @@ class ConstantPadNdStaticModule(torch.nn.Module):
|
|||
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -585,6 +586,8 @@ class SoftmaxIntModule(torch.nn.Module):
|
|||
def SoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class _SoftmaxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -718,22 +721,7 @@ class ContiguousModule(torch.nn.Module):
|
|||
def ContiguousModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1))
|
||||
|
||||
class TensorToInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToInt())
|
||||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10,[]))
|
||||
# ==============================================================================
|
||||
|
||||
class LogSoftmaxIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -752,6 +740,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
|||
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NumToTensorIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -769,6 +758,7 @@ class NumToTensorIntModule(torch.nn.Module):
|
|||
def NumToTensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NumToTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -808,6 +798,8 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
|
|||
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AddCMulModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -827,6 +819,8 @@ class AddCMulModule(torch.nn.Module):
|
|||
def AddCMulModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AddCDivModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -865,6 +859,8 @@ class tensorIntModule(torch.nn.Module):
|
|||
def TensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class tensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -902,6 +898,7 @@ class DropoutModule(torch.nn.Module):
|
|||
def DropoutModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -920,6 +917,7 @@ class MeanModule(torch.nn.Module):
|
|||
def MeanModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDynamicSizesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -938,6 +936,7 @@ class MeanDynamicSizesModule(torch.nn.Module):
|
|||
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NumelModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -955,6 +954,7 @@ class NumelModule(torch.nn.Module):
|
|||
def NumelModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NumelZeroRankModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -972,6 +972,7 @@ class NumelZeroRankModule(torch.nn.Module):
|
|||
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10,[]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BoolTensorReturnFalseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -990,6 +991,7 @@ class BoolTensorReturnFalseModule(torch.nn.Module):
|
|||
def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([0, 0], dtype=torch.bool))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BoolTensorReturnTrueModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1008,6 +1010,7 @@ class BoolTensorReturnTrueModule(torch.nn.Module):
|
|||
def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BoolTensorReturnMixedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToIntZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToIntZeroRank())
|
||||
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, ()))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToInt())
|
||||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (1, 1)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToFloatZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return float(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToFloatZeroRank())
|
||||
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((), dtype=torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return float(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToFloat())
|
||||
def TensorToFloat_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((1, 1), dtype=torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToBoolZeroRank(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.bool, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return bool(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToBoolZeroRank())
|
||||
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(1, dtype=torch.bool))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TensorToBool(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return bool(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToBool())
|
||||
def TensorToBool_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[1]], dtype=torch.bool))
|
||||
|
|
@ -51,6 +51,7 @@ from . import threshold
|
|||
from . import histogram_binning_calibration
|
||||
from . import table_batch_embedding
|
||||
from . import rng
|
||||
from . import cast
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -2935,6 +2935,21 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_FloatType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -3804,30 +3804,41 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Casts a 0d integer tensor to elemental type.
|
||||
namespace {
|
||||
class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
|
||||
// Casts a tensor of exactly one element to an elemental type.
|
||||
template <typename OpTy>
|
||||
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenIntTensorOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(OpTy op,
|
||||
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value intTensor = adaptor.a();
|
||||
auto tensorType = intTensor.getType().cast<RankedTensorType>();
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.a();
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
int64_t inputRank = inputSizes.size();
|
||||
|
||||
if (tensorType.getRank() != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid rank: the rank of the input tensor must be 0");
|
||||
// The `input` tensor must contain exactly one element, i.e., either the
|
||||
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
|
||||
// are unit.
|
||||
Value constantOne =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
for (int64_t i = 0; i < inputRank; i++)
|
||||
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
|
||||
// Extract the only element from the `input` tensor.
|
||||
Value constantZero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
SmallVector<Value> indices(inputRank, constantZero);
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace {
|
||||
class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
|
||||
public:
|
||||
|
@ -3853,7 +3864,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
|
@ -4618,8 +4628,13 @@ public:
|
|||
context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp>();
|
||||
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
|
||||
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
||||
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
|
|
|
@ -1229,13 +1229,31 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
// If an scalar number is converted to a 0-d tensor and passed on to
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Int.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
return numToTensorScalar.a();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenFloatTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Float.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
return numToTensorScalar.a();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
|
|
|
@ -169,7 +169,7 @@ static LogicalResult mungeFunction(
|
|||
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
|
||||
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
|
||||
op.emitError("Supported return types:"
|
||||
"mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64,"
|
||||
"mri1, mri32, mri64, mrf32, mrf64, i1, i64, f32, f64,"
|
||||
"(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
|
||||
"(mrf32, mrf32, mrf32)");
|
||||
isSupported = false;
|
||||
|
@ -195,6 +195,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
|
|||
Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
|
||||
Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
|
||||
Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
|
||||
Type i1 = b.getI1Type();
|
||||
Type i64 = b.getI64Type();
|
||||
Type f32 = b.getF32Type();
|
||||
Type f64 = b.getF64Type();
|
||||
|
@ -204,6 +205,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
|
|||
mri64,
|
||||
mrf32,
|
||||
mrf64,
|
||||
i1,
|
||||
i64,
|
||||
f32,
|
||||
f64,
|
||||
|
|
|
@ -622,6 +622,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::IntImplicit : (Tensor) -> (int)")
|
||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||
emit("aten::t : (Tensor) -> (Tensor)")
|
||||
|
||||
|
|
|
@ -53,6 +53,10 @@ class RefBackendInvoker:
|
|||
def consume_return_mrf64(a):
|
||||
self.result = unranked_memref_to_numpy(a, np.float64)
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.c_bool)
|
||||
def consume_return_i1(a):
|
||||
self.result = a
|
||||
|
||||
@ctypes.CFUNCTYPE(None, ctypes.c_int)
|
||||
def consume_return_i64(a):
|
||||
self.result = a
|
||||
|
@ -113,6 +117,9 @@ class RefBackendInvoker:
|
|||
self.ee.register_runtime("refbackend_consume_func_return_mrf64",
|
||||
consume_return_mrf64)
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_func_return_i1",
|
||||
consume_return_i1)
|
||||
|
||||
self.ee.register_runtime("refbackend_consume_func_return_i64",
|
||||
consume_return_i64)
|
||||
|
||||
|
|
|
@ -57,16 +57,117 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_extract
|
||||
// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.int
|
||||
func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.int
|
||||
func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
|
||||
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = tensor.dim %[[I]], %[[C0]] : tensor<?x?xi64>
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = tensor.dim %[[I]], %[[C1]] : tensor<?x?xi64>
|
||||
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.int
|
||||
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
|
||||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
|
||||
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.float
|
||||
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
|
||||
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
|
||||
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = tensor.dim %[[F]], %[[C0]] : tensor<?x?xf64>
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = tensor.dim %[[F]], %[[C1]] : tensor<?x?xf64>
|
||||
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.float
|
||||
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
|
||||
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
|
||||
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
|
||||
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
|
||||
// CHECK: return %[[RES]] : !torch.bool
|
||||
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
|
||||
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank
|
||||
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
|
||||
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xi1>
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xi1>
|
||||
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.bool
|
||||
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
|
||||
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -721,6 +721,16 @@ func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
|||
return %scalar : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.Float.Tensor(
|
||||
// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
|
||||
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>
|
||||
// CHECK: return %[[NUM]] : !torch.float
|
||||
func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
|
||||
%tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.float -> !torch.vtensor<[],f64>
|
||||
%scalar = torch.aten.Float.Tensor %tensor : !torch.vtensor<[],f64> -> !torch.float
|
||||
return %scalar : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.squeeze$zero_rank(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
||||
|
|
Loading…
Reference in New Issue