mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451, https://github.com/nod-ai/SHARK-Turbine/issues/452
parent 3cbe6c98ec
commit d81747eadb
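The hunks below touch the ONNX importer's Dropout handling, the Torch-to-Linalg lowering of division and of the scalar-extraction ops aten.IntImplicit / aten.FloatImplicit, their new canonicalizers, and the matching tests. As a quick orientation, here is a small illustrative Python sketch (not part of the diff; names, shapes, and values are arbitrary) of the operations the new e2e tests exercise:

import torch

# Scalar extraction from 0-d tensors: aten.FloatImplicit / aten.IntImplicit
# now get canonicalizations and a shared Torch-to-Linalg conversion pattern.
f = torch.ops.aten.FloatImplicit(torch.rand((), dtype=torch.float64))
i = torch.ops.aten.IntImplicit(torch.randint(-100, 100, ()))

# Elementwise division with integer operands: when the lowered result element
# type is an integer (e.g. onnx.Div on si32 or ui8 tensors), the Linalg payload
# now emits arith.divsi / arith.divui instead of rejecting non-float dtypes.
a = torch.randint(-10, 10, (3, 4))
b = torch.randint(1, 10, (3, 4), dtype=torch.int32)
q = torch.div(a, b)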
@@ -11231,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
@@ -11254,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
@@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value ratio, trainingMode;
         if (numOperands == 3) {
           ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
+          Value trainVal = operands[2];
+          auto trainTensorType =
+              trainVal.getType().dyn_cast<Torch::BaseTensorType>();
+          if (!trainTensorType)
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "train tensor must have a type");
+
+          Type inputDtype = trainTensorType.getOptionalDtype();
+          if (!inputDtype || !inputDtype.isInteger(1))
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "train tensor must have an integer dtype of width 1");
+
+          std::optional<unsigned> inputRank = Torch::getTensorRank(trainVal);
+          if (!inputRank || *inputRank != 0)
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "train tensor must have rank 0");
+
+          if (auto valueTensorLiteralOp =
+                  trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
+            auto val = valueTensorLiteralOp.getValue()
+                           .cast<DenseElementsAttr>()
+                           .getSplatValue<bool>();
+            trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
+          } else {
             Value trainingModeScalar =
                 rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
             Value cstOne = rewriter.create<Torch::ConstantIntOp>(
                 loc, rewriter.getI64IntegerAttr(1));
             trainingMode = rewriter.create<Torch::AtenEqIntOp>(
                 loc, trainingModeScalar, cstOne);
+          }
         } else if (numOperands == 2) {
           ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
           trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -191,12 +191,14 @@ public:
 } // namespace
 
 namespace {
-class ConvertAtenScalarImplicitOp
-    : public OpConversionPattern<AtenScalarImplicitOp> {
+// Converts a tensor with one element to a scalar value.
+template <typename OpTy>
+class ConvertAtenImplicitLikeOp : public OpConversionPattern<OpTy> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op,
+                  typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
     return success();
@@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
   target.addIllegalOp<PrimNumToTensorScalarOp>();
   patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
-  patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context);
-  target.addIllegalOp<AtenScalarImplicitOp>();
+  patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
+                                                                context);
+  patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
+                                                               context);
+  patterns.add<ConvertAtenImplicitLikeOp<AtenIntImplicitOp>>(typeConverter,
+                                                             context);
+  target.addIllegalOp<AtenScalarImplicitOp, AtenFloatImplicitOp,
+                      AtenIntImplicitOp>();
 }
@@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(div.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
       return b.create<arith::DivFOp>(loc, lhs, rhs);
+    else if (dtype.isa<mlir::IntegerType>()) {
+      if (dtype.isUnsignedInteger())
+        return b.create<arith::DivUIOp>(loc, lhs, rhs);
+      return b.create<arith::DivSIOp>(loc, lhs, rhs);
+    }
+    div.emitError("unimplemented: non-floating point and non-integer dtype");
+    return nullptr;
   }
   if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
     AtenDivTensorModeOp::Adaptor adaptor(operands);
@@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenFloatImplicitOp
+//===----------------------------------------------------------------------===//
+void AtenFloatImplicitOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) {
+    Location loc = op.getLoc();
+    Value a = op.getA();
+    Value scalarValue = getScalarFloatValue(a, loc, rewriter);
+    if (!scalarValue)
+      return failure();
+    rewriter.replaceOp(op, scalarValue);
+    return success();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenIntImplicitOp
+//===----------------------------------------------------------------------===//
+void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    MLIRContext *context) {
+  patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
+    Location loc = op.getLoc();
+    Value a = op.getA();
+    Value scalarValue = getScalarIntValue(a, loc, rewriter);
+    if (!scalarValue)
+      return failure();
+    rewriter.replaceOp(op, scalarValue);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSizeOp
 //===----------------------------------------------------------------------===//
@@ -335,6 +335,9 @@ TORCHDYNAMO_XFAIL_SET = {
 
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
+
+    "FloatImplicitModule_basic",
+    "IntImplicitModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -989,6 +992,8 @@ TOSA_PASS_SET = {
     "ElementwiseCloneContiguousModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseDivScalarModule_basic",
+    "ElementwiseDivTensorIntegerModule_basic",
+    "ElementwiseDivTensorUnsignedIntegerModule_basic",
     "ElementwiseEluModule_basic",
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseEqBoolScalarModule_basic",
@@ -2146,8 +2151,6 @@ ONNX_XFAIL_SET = {
    "ElementwiseSigmoidIntModule_basic",
 
    # Failure - unknown
-   "ChunkListUnpackUneven_Module_basic",
-   "ChunkListUnpack_Module_basic",
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "CopyWithDifferentDTypesAndSizesModule_basic",
    "CopyWithDifferentDTypesModule_basic",
@@ -2168,6 +2171,8 @@ ONNX_XFAIL_SET = {
    "ReduceMinAlongDimUnsignedInt_basic",
    "TensorsStackNegativeDimModule_basic",
    "TensorsStackPromoteDTypeModule_basic",
+   "FloatImplicitModule_basic",
+   "IntImplicitModule_basic",
 }
 
 ONNX_CRASHING_SET = { }
@@ -669,8 +669,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
     emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
     emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
-    emit("aten::IntImplicit : (Tensor) -> (int)")
-    emit("aten::FloatImplicit : (Tensor) -> (float)")
+    emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
+    emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
     emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
     emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
@@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(low=-100, high=100))
 
 
+# ==============================================================================
+
+
+class FloatImplicitModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+    ])
+    def forward(self, x):
+        return float(torch.ops.aten.FloatImplicit(x))
+
+
+@register_test_case(module_factory=lambda: FloatImplicitModule())
+def FloatImplicitModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand().double())
+
+
+# ==============================================================================
+
+
+class IntImplicitModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+    ])
+    def forward(self, x):
+        return float(torch.ops.aten.IntImplicit(x))
+
+
+@register_test_case(module_factory=lambda: IntImplicitModule())
+def IntImplicitModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint())
+
+
 # ==============================================================================
 
 
 class PowIntFloat(torch.nn.Module):
@@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseDivTensorIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule())
+def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32))
+
+
+# ==============================================================================
+
+
+class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.uint8, True),
+        ([-1, -1], torch.uint8, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule())
+def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8))
+
+
+# ==============================================================================
+
+
 class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
 
     def __init__(self):
@@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3
 
 // -----
 
+// CHECK-LABEL: @test_div_int32
+func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32>
+  %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32>
+  return %0 : !torch.vtensor<[3,4,5],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_div_uint8
 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8>
@@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
   return %1 : !torch.number
 }
 
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: return %[[FLOAT1]] : !torch.float
+func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
+  %float1 = torch.constant.float 1.0
+  %0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64>
+  %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
+  return %1 : !torch.float
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: return %[[FLOAT1]] : !torch.float
+func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
+  %0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
+  %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
+  return %1 : !torch.float
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: return %[[INT1]] : !torch.int
+func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
+  %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
+  return %1 : !torch.int
+}
+
+// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: return %[[INT1]] : !torch.int
+func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
+  %0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+  %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
+  return %1 : !torch.int
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.prims.view_of$fold(
 // CHECK-SAME:  %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
 // CHECK-NEXT:  return %[[ARG]] : !torch.vtensor<[3,4,2],f32>