diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 37d3c1308..3bcbd0c7c 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -33,6 +33,8 @@ from . import conv from . import batchnorm from . import quantized_models from . import elementwise +from . import type_promotion +from . import type_conversion from . import reduction from . import argmax from . import matmul diff --git a/e2e_testing/torchscript/type_conversion.py b/e2e_testing/torchscript/type_conversion.py new file mode 100644 index 000000000..126ba1cb0 --- /dev/null +++ b/e2e_testing/torchscript/type_conversion.py @@ -0,0 +1,77 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestUtils +from torch_mlir_e2e_test.torchscript.registry import register_test_case +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export + +# ============================================================================== + +class TypeConversionF32ToF64Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True) + ]) + def forward(self, x): + return x.to(torch.float64) + +@register_test_case(module_factory=lambda: TypeConversionF32ToF64Module()) +def TypeConversionF32ToF64Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + + +class TypeConversionF64ToF32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float64, True) + ]) + def forward(self, x): + return x.to(torch.float32) + +@register_test_case(module_factory=lambda: TypeConversionF64ToF32Module()) +def TypeConversionF64ToF32Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5).type(torch.float64)) + +class TypeConversionI32ToI64Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True) + ]) + def forward(self, x): + return x.to(torch.int64) + +@register_test_case(module_factory=lambda: TypeConversionI32ToI64Module()) +def TypeConversionI32ToI64Module_basic(module, tu: TestUtils): + module.forward(torch.randint(5, [2, 3]).type(torch.int32)) + +class TypeConversionI64ToI32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True) + ]) + def forward(self, x): + return x.to(torch.int32) + +@register_test_case(module_factory=lambda: TypeConversionI64ToI32Module()) +def TypeConversionI64ToI32Module_basic(module, tu: TestUtils): + module.forward(torch.randint(5, [2, 3])) diff --git a/e2e_testing/torchscript/type_promotion.py b/e2e_testing/torchscript/type_promotion.py new file mode 100644 index 000000000..6cad4ef03 --- /dev/null +++ b/e2e_testing/torchscript/type_promotion.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestUtils +from torch_mlir_e2e_test.torchscript.registry import register_test_case +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export + +# ============================================================================== + + +class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.int64, True), + ]) + def forward(self, a, b): + return torch.add(a, b, alpha=3) + + +@register_test_case( + module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule()) +def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils): + module.forward( + torch.randint(10, [4]).type(torch.int32), + torch.randint(10, [4])) + + +class TypePromotionDifferentCategoryModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.add(a, b, alpha=3) + + +@register_test_case( + module_factory=lambda: TypePromotionDifferentCategoryModule()) +def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils): + module.forward(torch.randint(10, [4]), torch.randn(4)) + + +class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([], torch.float64, True), + ]) + def forward(self, a, b): + return torch.add(a, b, alpha=2.3) + + +@register_test_case( + module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule()) +def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand().type(torch.float64)) + + +class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([], torch.float32, True), + ]) + def forward(self, a, b): + return torch.add(a, b, alpha=2) + + +@register_test_case( + module_factory=lambda: TypePromotionZeroRankHigherCategoryModule()) +def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils): + module.forward(torch.randint(10, [4]), tu.rand()) + + +class TypePromotionAlphaWiderModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([], torch.float32, True), + ]) + def forward(self, a, b): + return torch.add(a, b, alpha=2.3) + + +@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule()) +def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand()) diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 96c7b6d5a..8adb6ce06 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1251,29 +1251,65 @@ public: }; } // namespace +// Convert a scalar value to the target type. The scalar value can be an element +// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype +// should be converted builtin types. +static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, + Type dtype) { + Type scalarType = scalar.getType(); + if (scalarType == dtype) + return scalar; -static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) { - // TODO: For the integer case, we probably need the unconverted dtype to + // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to // be able to know if we need signed or unsigned conversion. - if (dtype.isa()) { - if (scalar.getType().isa()) { - // `scalar` will always be f64 since that is what the TypeConverter - // converts !torch.float to. - return b.create(loc, scalar, dtype); - } else { - assert(scalar.getType().isa()); - // `scalar` will always be i64 since that is what the TypeConverter - // converts !torch.int to. - return b.create(loc, scalar, dtype); + auto isByteOrChar = [](Type type) { + if (auto integerTy = type.dyn_cast()) { + return integerTy.getWidth() == 8; } + return false; + }; + + if (isByteOrChar(scalarType) || isByteOrChar(dtype) || + scalarType.isSignlessInteger(1) || dtype.isSignlessInteger(1)) { + // TODO: Handle bool type. + mlir::emitError(loc) + << "unsupported byte, char or bool type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; } - mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype; - return nullptr; + + if (auto dtypeFloat = dtype.dyn_cast()) { + if (auto scalarFloat = scalarType.dyn_cast()) { + if (scalarFloat.getWidth() > dtypeFloat.getWidth()) + return b.create(loc, scalar, dtype); + // Only scalarFloat width < dtypeFloat width can reach here. + return b.create(loc, scalar, dtype); + } + assert(scalarType.isa()); + // It's safe to use SIToFPOp because ui8/si8 are the only ones where + // unsigned handling is needed, and we checked for that case above. + return b.create(loc, scalar, dtype); + } + + if (auto dtypeInteger = dtype.dyn_cast()) { + if (auto scalarFloat = scalarType.dyn_cast()) + return b.create(loc, scalar, dtype); + assert(scalarType.isa()); + auto scalarInteger = scalarType.cast(); + if (scalarInteger.getWidth() > dtypeInteger.getWidth()) + return b.create(loc, scalar, dtype); + // Only scalarInteger width < dtypeInteger width can reach here. + // It's safe to use ExtSIOp here because ui8/si8 are the only ones where + // unsigned handling is needed, and we checked for that case above. + return b.create(loc, scalar, dtype); + } + + llvm_unreachable("convertScalarToDtype should handle all the types"); } static Value createLinalgPayloadCalculationForElementwiseOp( - OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, - ArrayRef operands) { + OpBuilder &b, Location loc, TypeConverter *converter, + ValueRange payloadArgs, Operation *op, ArrayRef operands) { if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) @@ -1322,40 +1358,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); - if (add.alpha().getType().isa()) { - add.emitError("unimplemented: !torch.float 'alpha'"); - return nullptr; + Type dtype = converter->convertType(add.getType()) + .cast() + .getElementType(); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype); + if (dtype.isa()) { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); + } else { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); } - if (!add.getType() - .cast() - .getDtype() - .isa()) { - add.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } - Value alphaFloat = b.create(loc, payloadArgs[0].getType(), - adaptor.alpha()); - Value scaled = b.create(loc, payloadArgs[1], alphaFloat); - return b.create(loc, payloadArgs[0], scaled); } if (auto sub = dyn_cast(op)) { AtenSubTensorOp::Adaptor adaptor(operands); - if (sub.alpha().getType().isa()) { - sub.emitError("unimplemented: !torch.float 'alpha'"); - return nullptr; + Type dtype = converter->convertType(sub.getType()) + .cast() + .getElementType(); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype); + if (dtype.isa()) { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); + } else { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); } - if (!sub.getType() - .cast() - .getDtype() - .isa()) { - sub.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } - Value alphaFloat = b.create(loc, payloadArgs[0].getType(), - adaptor.alpha()); - Value scaled = b.create(loc, payloadArgs[1], alphaFloat); - - return b.create(loc, payloadArgs[0], scaled); } if (auto mul = dyn_cast(op)) { if (!mul.getType() @@ -1386,7 +1417,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } Type dtype = pow.self().getType().cast().getDtype(); - Value expPromoted = promoteScalarToDtype(b, loc, operands[1], dtype); + Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, payloadArgs[0], expPromoted); } if (auto lerp = dyn_cast(op)) { @@ -1430,7 +1461,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, payloadArgs[0], payloadArgs[1]); } if (auto clamp = dyn_cast(op)) { - auto dtype = clamp.getType().cast().getDtype(); + Type dtype = converter->convertType(clamp.getType()) + .cast() + .getElementType(); if (!dtype.isa()) { clamp.emitError("unimplemented: non-floating point dtype"); return nullptr; @@ -1445,13 +1478,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } auto result = payloadArgs[0]; if (!min.getType().isa()) { - auto minPromoted = promoteScalarToDtype(b, loc, min, dtype); + auto minPromoted = convertScalarToDtype(b, loc, min, dtype); auto pred = b.create(loc, arith::CmpFPredicate::ULT, result, minPromoted); result = b.create(loc, pred, minPromoted, result); } if (!max.getType().isa()) { - auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype); + auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); auto pred = b.create(loc, arith::CmpFPredicate::UGT, result, maxPromoted); result = b.create(loc, pred, maxPromoted, result); @@ -1459,36 +1492,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto rsub = dyn_cast(op)) { - if (!rsub.getType() - .cast() - .getDtype() - .isa()) { + Type dtype = converter->convertType(rsub.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { rsub.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value self = payloadArgs[0]; - Value other = promoteScalarToDtype(b, loc, operands[1], self.getType()); - Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType()); + Value other = convertScalarToDtype(b, loc, operands[1], dtype); + Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; - Type inType = input.getType(); - Type outType = atenToDtype.getType().cast().getDtype(); - Value result; - if (!inType.isF32()) { - atenToDtype.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } - if (inType == outType) - result = input; - else if (outType.isInteger(64)) - result = b.create(loc, b.getI64Type(), input); - else if (outType.isInteger(1)) - result = b.create(loc, b.getI1Type(), input); - else - atenToDtype.emitError("unimplemented: unsupported target dtype"); + Type dtype = converter->convertType(atenToDtype.getType()) + .cast() + .getElementType(); + Value result = convertScalarToDtype(b, loc, input, dtype); return result; } @@ -1808,7 +1830,7 @@ struct ConvertElementwiseOp : ConversionPattern { /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadCalculationForElementwiseOp( - b, loc, payloadArgs, op, operands); + b, loc, getTypeConverter(), payloadArgs, op, operands); if (!result) { hadErrorCreatingPayload = true; return; @@ -2161,7 +2183,7 @@ public: } SmallVector expectedSize = getTypeConvertedValues( rewriter, loc, typeConverter, expectedSizeTorchInt); - if (expectedSize.size() != resultRank) { + if (resultRank != (int64_t)expectedSize.size()) { return rewriter.notifyMatchFailure( op, "desired size list length mismatches with the result type rank"); } diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index c6296c2e0..7f8caea9f 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -229,8 +229,8 @@ public: AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, - AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp, - AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) { + AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, + AtenFloorOp, AtenLog2Op>(op)) { return getLatticeElement(op->getResult(0)).join(*operands[0]); } @@ -285,7 +285,7 @@ public: return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands); } else if (isa(op)) { + AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) { return visitBinaryTensorScalarOp(op, operands); } else if (isa()) { if (integerTy.isSignlessInteger(64)) return true; + if (integerTy.isSignlessInteger(32)) + return true; } } return false; @@ -109,7 +111,7 @@ static LogicalResult mungeFunction( auto type = arg.getType(); if (!isArgMemRefTypeValid(type)) return emitError(arg.getLoc(), - "argument must be a memref of f32, f64, i64"); + "argument must be a memref of f32, f64, i32, i64"); auto cast = b.create(arg.getLoc(), arg, type); arg.replaceAllUsesExcept(cast, cast); arg.setType(getAbiTypeForMemRef(type)); @@ -175,6 +177,8 @@ class MungeCallingConventions }; // Memref return types. + createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0), + "refbackend_consume_memref_int32_func_return"); createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0), "refbackend_consume_memref_int64_func_return"); createConsumeFuncReturnFunc( diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 2907e6647..46f27ef3e 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -24,7 +24,7 @@ __all__ = [ def checkArgTypeIsSupported(ty): - SUPPORTED = [np.float32, np.float64, np.int64] + SUPPORTED = [np.float32, np.float64, np.int32, np.int64] assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported" class RefBackendInvoker: @@ -32,6 +32,10 @@ class RefBackendInvoker: self.ee = ExecutionEngine(module) self.result = None + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) + def consume_memref_i32_return(a): + self.result = unranked_memref_to_numpy(a, np.int32) + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) def consume_memref_i64_return(a): self.result = unranked_memref_to_numpy(a, np.int64) @@ -56,6 +60,9 @@ class RefBackendInvoker: def consume_f64_return(a): self.result = a + self.ee.register_runtime("refbackend_consume_memref_int32_func_return", + consume_memref_i32_return) + self.ee.register_runtime("refbackend_consume_memref_int64_func_return", consume_memref_i64_return) diff --git a/python/torch_mlir_e2e_test/torchscript/reporting.py b/python/torch_mlir_e2e_test/torchscript/reporting.py index 75ec3e781..90600d0eb 100644 --- a/python/torch_mlir_e2e_test/torchscript/reporting.py +++ b/python/torch_mlir_e2e_test/torchscript/reporting.py @@ -20,9 +20,9 @@ from .framework import TestResult, TraceItem class TensorSummary: """A summary of a tensor's contents.""" def __init__(self, tensor): - self.min = torch.min(tensor) - self.max = torch.max(tensor) - self.mean = torch.mean(tensor) + self.min = torch.min(tensor.type(torch.float64)) + self.max = torch.max(tensor.type(torch.float64)) + self.mean = torch.mean(tensor.type(torch.float64)) self.shape = list(tensor.shape) def __str__(self): @@ -148,10 +148,15 @@ class ValueReport: if isinstance(golden, torch.Tensor): if not isinstance(value, torch.Tensor): return self._record_mismatch_type_failure('torch.Tensor', value) + if value.shape != golden.shape: return self._record_failure( f'shape ({value.shape}) is not equal to golden shape ({golden.shape})' ) + if value.dtype != golden.dtype: + return self._record_failure( + f'shape ({value.dtype}) is not equal to golden dtype ({golden.dtype})' + ) if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True): return self._record_failure( f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})' diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index e210ca936..ec11c38a6 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -980,12 +980,11 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: // ---- // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix( -// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>, -// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -// CHECK-SAME: -> !torch.tensor { -// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor +// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>, +// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { +// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor +// CHECK: return %[[CAST]] : !torch.tensor func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor return %0 : !torch.tensor @@ -994,27 +993,25 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, // ---- // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector( -// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>, -// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -// CHECK-SAME: -> !torch.tensor { -// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor +// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>, +// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor { +// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor +// CHECK: return %[[CAST]] : !torch.tensor func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor return %0 : !torch.tensor } // ----- -// CHECK-LABEL: func @torch.aten.to.dtype -// CHECK-SAME: (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor +// CHECK-LABEL: func @torch.aten.to.dtype( +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor // CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype // CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : // CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none // CHECK-SAME: -> !torch.tensor<[?,?],si64> // CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor // CHECK-NEXT: return %[[RES]] : !torch.tensor - func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{ %none = torch.constant.none %false = torch.constant.bool false @@ -1025,12 +1022,10 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{ // ---- // CHECK-LABEL: func @torch.prim.NumToTensor.Scalar( -// CHECK-SAME: %[[SELF:.*]]: !torch.int) -// CHECK-SAME: -> !torch.tensor { -// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor - +// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor { +// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor +// CHECK: return %[[CAST]] : !torch.tensor func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor { %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor return %0: !torch.tensor