mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.[Bool.Tensor|Float.Tensor]` op
- This commit adds lowering of `aten.Bool.Tensor` and `aten.Float.Tensor` op as a part of `convert-torch-to-linalg` pass. - It also adds support for returning bool types. - It also fixes lowering of the `aten.Int.Tensor` op for non-zero rank input tensors. - If a scalar number is converted to a 0-d tensor and passed on to the `aten.Float.Tensor` op, it folds to the scalar number. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/593/head
parent
9e7b6cab08
commit
f00d1686c8
|
@ -122,7 +122,6 @@ def AddmmModuleFloat_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AddmmModuleBroadcastable(torch.nn.Module):
|
class AddmmModuleBroadcastable(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -144,7 +143,6 @@ def AddmmModule_broadcastable(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
|
class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -166,7 +164,6 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveAvgPool2dModule(torch.nn.Module):
|
class AdaptiveAvgPool2dModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -263,6 +260,7 @@ class MaxPool2dModule(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.mp2d(x)
|
return self.mp2d(x)
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||||
|
@ -328,6 +326,7 @@ class ConstantPadNdModule(torch.nn.Module):
|
||||||
def ConstantPadNdModule_basic(module, tu: TestUtils):
|
def ConstantPadNdModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ConstantPadNdStaticModule(torch.nn.Module):
|
class ConstantPadNdStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -346,6 +345,8 @@ class ConstantPadNdStaticModule(torch.nn.Module):
|
||||||
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
|
def ConstantPadNdStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
module.forward(tu.rand(1, 1, 20, 20, 4, 4) - 0.5)
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -585,6 +586,8 @@ class SoftmaxIntModule(torch.nn.Module):
|
||||||
def SoftmaxIntModule_basic(module, tu: TestUtils):
|
def SoftmaxIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 2, 4))
|
module.forward(torch.randn(3, 2, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class _SoftmaxModule(torch.nn.Module):
|
class _SoftmaxModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -718,22 +721,7 @@ class ContiguousModule(torch.nn.Module):
|
||||||
def ContiguousModule_basic(module, tu: TestUtils):
|
def ContiguousModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 1))
|
module.forward(tu.rand(3, 1))
|
||||||
|
|
||||||
class TensorToInt(torch.nn.Module):
|
# ==============================================================================
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([], torch.int64, True),
|
|
||||||
])
|
|
||||||
def forward(self, x):
|
|
||||||
return int(x)
|
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TensorToInt())
|
|
||||||
def TensorToInt_basic(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randint(10,[]))
|
|
||||||
|
|
||||||
class LogSoftmaxIntModule(torch.nn.Module):
|
class LogSoftmaxIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -752,6 +740,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
||||||
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 2, 4).double())
|
module.forward(torch.randn(3, 2, 4).double())
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class NumToTensorIntModule(torch.nn.Module):
|
class NumToTensorIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -769,6 +758,7 @@ class NumToTensorIntModule(torch.nn.Module):
|
||||||
def NumToTensorIntModule_basic(module, tu: TestUtils):
|
def NumToTensorIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class NumToTensorFloatModule(torch.nn.Module):
|
class NumToTensorFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -808,6 +798,8 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
|
||||||
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
|
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
|
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class AddCMulModule(torch.nn.Module):
|
class AddCMulModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -827,6 +819,8 @@ class AddCMulModule(torch.nn.Module):
|
||||||
def AddCMulModule_basic(module, tu: TestUtils):
|
def AddCMulModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
|
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class AddCDivModule(torch.nn.Module):
|
class AddCDivModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -865,6 +859,8 @@ class tensorIntModule(torch.nn.Module):
|
||||||
def TensorIntModule_basic(module, tu: TestUtils):
|
def TensorIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class tensorFloatModule(torch.nn.Module):
|
class tensorFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -902,6 +898,7 @@ class DropoutModule(torch.nn.Module):
|
||||||
def DropoutModule_basic(module, tu: TestUtils):
|
def DropoutModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class MeanModule(torch.nn.Module):
|
class MeanModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -920,6 +917,7 @@ class MeanModule(torch.nn.Module):
|
||||||
def MeanModule_basic(module, tu: TestUtils):
|
def MeanModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 4))
|
module.forward(torch.randn(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class MeanDynamicSizesModule(torch.nn.Module):
|
class MeanDynamicSizesModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -938,6 +936,7 @@ class MeanDynamicSizesModule(torch.nn.Module):
|
||||||
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 4))
|
module.forward(torch.randn(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class NumelModule(torch.nn.Module):
|
class NumelModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -955,6 +954,7 @@ class NumelModule(torch.nn.Module):
|
||||||
def NumelModule_basic(module, tu: TestUtils):
|
def NumelModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 3, 5))
|
module.forward(tu.rand(4, 3, 5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class NumelZeroRankModule(torch.nn.Module):
|
class NumelZeroRankModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -972,6 +972,7 @@ class NumelZeroRankModule(torch.nn.Module):
|
||||||
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(10,[]))
|
module.forward(torch.randint(10,[]))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class BoolTensorReturnFalseModule(torch.nn.Module):
|
class BoolTensorReturnFalseModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -990,6 +991,7 @@ class BoolTensorReturnFalseModule(torch.nn.Module):
|
||||||
def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
|
def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor([0, 0], dtype=torch.bool))
|
module.forward(torch.tensor([0, 0], dtype=torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class BoolTensorReturnTrueModule(torch.nn.Module):
|
class BoolTensorReturnTrueModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -1008,6 +1010,7 @@ class BoolTensorReturnTrueModule(torch.nn.Module):
|
||||||
def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
|
def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
|
module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class BoolTensorReturnMixedModule(torch.nn.Module):
|
class BoolTensorReturnMixedModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToIntZeroRank(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return int(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToIntZeroRank())
|
||||||
|
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, ()))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return int(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToInt())
|
||||||
|
def TensorToInt_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (1, 1)))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToFloatZeroRank(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return float(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToFloatZeroRank())
|
||||||
|
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand((), dtype=torch.float64))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToFloat(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return float(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToFloat())
|
||||||
|
def TensorToFloat_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand((1, 1), dtype=torch.float64))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToBoolZeroRank(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return bool(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToBoolZeroRank())
|
||||||
|
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor(1, dtype=torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TensorToBool(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return bool(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorToBool())
|
||||||
|
def TensorToBool_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([[1]], dtype=torch.bool))
|
||||||
|
|
|
@ -51,6 +51,7 @@ from . import threshold
|
||||||
from . import histogram_binning_calibration
|
from . import histogram_binning_calibration
|
||||||
from . import table_batch_embedding
|
from . import table_batch_embedding
|
||||||
from . import rng
|
from . import rng
|
||||||
|
from . import cast
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||||
|
|
|
@ -2935,6 +2935,21 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::Float.Tensor : (Tensor) -> (float)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$a
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_FloatType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$a attr-dict `:` qualified(type($a)) `->` qualified(type($result))";
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -3804,30 +3804,41 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Casts a 0d integer tensor to elemental type.
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
|
// Casts a tensor of exactly one element to an elemental type.
|
||||||
|
template <typename OpTy>
|
||||||
|
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenIntTensorOp op, OpAdaptor adaptor,
|
matchAndRewrite(OpTy op,
|
||||||
|
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
Value intTensor = adaptor.a();
|
Location loc = op.getLoc();
|
||||||
auto tensorType = intTensor.getType().cast<RankedTensorType>();
|
Value input = adaptor.a();
|
||||||
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
|
int64_t inputRank = inputSizes.size();
|
||||||
|
|
||||||
if (tensorType.getRank() != 0)
|
// The `input` tensor must contain exactly one element, i.e., either the
|
||||||
return rewriter.notifyMatchFailure(
|
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
|
||||||
op, "invalid rank: the rank of the input tensor must be 0");
|
// are unit.
|
||||||
|
Value constantOne =
|
||||||
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
for (int64_t i = 0; i < inputRank; i++)
|
||||||
|
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
|
// Extract the only element from the `input` tensor.
|
||||||
|
Value constantZero =
|
||||||
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<Value> indices(inputRank, constantZero);
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, input, indices);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
|
class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -3853,7 +3864,6 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -4618,8 +4628,13 @@ public:
|
||||||
context);
|
context);
|
||||||
target.addIllegalOp<AtenContiguousOp>();
|
target.addIllegalOp<AtenContiguousOp>();
|
||||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenIntTensorOp>();
|
target.addIllegalOp<AtenIntTensorOp, AtenFloatTensorOp, AtenBoolTensorOp>();
|
||||||
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
|
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenIntTensorOp>>(
|
||||||
|
typeConverter, context);
|
||||||
|
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenFloatTensorOp>>(
|
||||||
|
typeConverter, context);
|
||||||
|
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenBoolTensorOp>>(
|
||||||
|
typeConverter, context);
|
||||||
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
||||||
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenDropoutOp>();
|
target.addIllegalOp<AtenDropoutOp>();
|
||||||
|
|
|
@ -1229,13 +1229,31 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenIntTensorOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// If an scalar number is converted to a 0-d tensor and passed on to
|
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||||
// aten.Int.Tensor, fold to the scalar number.
|
// aten.Int.Tensor, fold to the scalar number.
|
||||||
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||||
return numToTensorScalar.a();
|
return numToTensorScalar.a();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenFloatTensorOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||||
|
// aten.Float.Tensor, fold to the scalar number.
|
||||||
|
if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||||
|
return numToTensorScalar.a();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
|
|
|
@ -169,7 +169,7 @@ static LogicalResult mungeFunction(
|
||||||
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
|
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
|
||||||
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
|
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
|
||||||
op.emitError("Supported return types:"
|
op.emitError("Supported return types:"
|
||||||
"mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64,"
|
"mri1, mri32, mri64, mrf32, mrf64, i1, i64, f32, f64,"
|
||||||
"(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
|
"(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
|
||||||
"(mrf32, mrf32, mrf32)");
|
"(mrf32, mrf32, mrf32)");
|
||||||
isSupported = false;
|
isSupported = false;
|
||||||
|
@ -195,6 +195,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
|
||||||
Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
|
Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
|
||||||
Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
|
Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
|
||||||
Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
|
Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
|
||||||
|
Type i1 = b.getI1Type();
|
||||||
Type i64 = b.getI64Type();
|
Type i64 = b.getI64Type();
|
||||||
Type f32 = b.getF32Type();
|
Type f32 = b.getF32Type();
|
||||||
Type f64 = b.getF64Type();
|
Type f64 = b.getF64Type();
|
||||||
|
@ -204,6 +205,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
|
||||||
mri64,
|
mri64,
|
||||||
mrf32,
|
mrf32,
|
||||||
mrf64,
|
mrf64,
|
||||||
|
i1,
|
||||||
i64,
|
i64,
|
||||||
f32,
|
f32,
|
||||||
f64,
|
f64,
|
||||||
|
|
|
@ -622,6 +622,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::IntImplicit : (Tensor) -> (int)")
|
emit("aten::IntImplicit : (Tensor) -> (int)")
|
||||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||||
|
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||||
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
emit("aten::t : (Tensor) -> (Tensor)")
|
emit("aten::t : (Tensor) -> (Tensor)")
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,10 @@ class RefBackendInvoker:
|
||||||
def consume_return_mrf64(a):
|
def consume_return_mrf64(a):
|
||||||
self.result = unranked_memref_to_numpy(a, np.float64)
|
self.result = unranked_memref_to_numpy(a, np.float64)
|
||||||
|
|
||||||
|
@ctypes.CFUNCTYPE(None, ctypes.c_bool)
|
||||||
|
def consume_return_i1(a):
|
||||||
|
self.result = a
|
||||||
|
|
||||||
@ctypes.CFUNCTYPE(None, ctypes.c_int)
|
@ctypes.CFUNCTYPE(None, ctypes.c_int)
|
||||||
def consume_return_i64(a):
|
def consume_return_i64(a):
|
||||||
self.result = a
|
self.result = a
|
||||||
|
@ -113,6 +117,9 @@ class RefBackendInvoker:
|
||||||
self.ee.register_runtime("refbackend_consume_func_return_mrf64",
|
self.ee.register_runtime("refbackend_consume_func_return_mrf64",
|
||||||
consume_return_mrf64)
|
consume_return_mrf64)
|
||||||
|
|
||||||
|
self.ee.register_runtime("refbackend_consume_func_return_i1",
|
||||||
|
consume_return_i1)
|
||||||
|
|
||||||
self.ee.register_runtime("refbackend_consume_func_return_i64",
|
self.ee.register_runtime("refbackend_consume_func_return_i64",
|
||||||
consume_return_i64)
|
consume_return_i64)
|
||||||
|
|
||||||
|
|
|
@ -57,19 +57,120 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @integer_extract
|
// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank
|
||||||
// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
|
||||||
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
|
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
|
||||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
|
||||||
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||||
// CHECK: return %[[RET]] : !torch.int
|
// CHECK: return %[[RET]] : !torch.int
|
||||||
func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
||||||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
return %0 : !torch.int
|
return %0 : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
|
||||||
|
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM0:.*]] = tensor.dim %[[I]], %[[C0]] : tensor<?x?xi64>
|
||||||
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[DIM1:.*]] = tensor.dim %[[I]], %[[C1]] : tensor<?x?xi64>
|
||||||
|
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||||
|
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||||
|
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
|
||||||
|
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.int
|
||||||
|
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
|
||||||
|
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
|
||||||
|
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
|
||||||
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
|
||||||
|
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.float
|
||||||
|
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
|
||||||
|
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
|
||||||
|
return %0 : !torch.float
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
|
||||||
|
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
|
||||||
|
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM0:.*]] = tensor.dim %[[F]], %[[C0]] : tensor<?x?xf64>
|
||||||
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[DIM1:.*]] = tensor.dim %[[F]], %[[C1]] : tensor<?x?xf64>
|
||||||
|
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||||
|
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||||
|
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
|
||||||
|
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.float
|
||||||
|
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
|
||||||
|
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
|
||||||
|
return %0 : !torch.float
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
|
||||||
|
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
|
||||||
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
|
||||||
|
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
|
||||||
|
// CHECK: return %[[RES]] : !torch.bool
|
||||||
|
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
|
||||||
|
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
|
||||||
|
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xi1>
|
||||||
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xi1>
|
||||||
|
// CHECK: %[[ONE:.*]] = arith.constant 1 : i64
|
||||||
|
// CHECK: %[[DIM0_INDEX:.*]] = arith.index_cast %[[DIM0]] : index to i64
|
||||||
|
// CHECK: %[[PRED0:.*]] = arith.cmpi eq, %[[DIM0_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED0]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[DIM1_INDEX:.*]] = arith.index_cast %[[DIM1]] : index to i64
|
||||||
|
// CHECK: %[[PRED1:.*]] = arith.cmpi eq, %[[DIM1_INDEX]], %[[ONE]] : i64
|
||||||
|
// CHECK: assert %[[PRED1]], "mismatching contracting dimension"
|
||||||
|
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
|
||||||
|
// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.bool
|
||||||
|
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
|
||||||
|
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
|
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
|
||||||
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
|
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
|
||||||
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
|
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
|
||||||
|
|
|
@ -721,6 +721,16 @@ func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
||||||
return %scalar : !torch.int
|
return %scalar : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.Float.Tensor(
|
||||||
|
// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
|
||||||
|
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>
|
||||||
|
// CHECK: return %[[NUM]] : !torch.float
|
||||||
|
func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
|
||||||
|
%tensor = torch.prim.NumToTensor.Scalar %arg0: !torch.float -> !torch.vtensor<[],f64>
|
||||||
|
%scalar = torch.aten.Float.Tensor %tensor : !torch.vtensor<[],f64> -> !torch.float
|
||||||
|
return %scalar : !torch.float
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.squeeze$zero_rank(
|
// CHECK-LABEL: func @torch.aten.squeeze$zero_rank(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
||||||
|
|
Loading…
Reference in New Issue