Add convertScalarToDtype helper.

This facilitates scalar type conversion in the TorchToLinalg lowering. As
part of adding the helper, this PR also:
- Updated `AtenAddTensorOp` and `AtenSubTensorOp` to use the helper to
support more type variants.
- Added e2e type promotion testing.
- Added i32 memref return/arg type to support e2e testing.
Yi Zhang 2021-10-15 18:23:40 -04:00
parent e23cabf3a9
commit 05c4dd8e39
9 changed files with 326 additions and 101 deletions
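For orientation, a minimal PyTorch-level sketch (illustration only, not part of the patch) of the behavior the new e2e tests below pin down: `Tensor.to` dtype conversion, and the type promotion that the `AtenAddTensorOp`/`AtenSubTensorOp` lowerings now have to honor when operand dtypes differ.

import torch

# dtype conversion (exercised by the new type_conversion tests)
x = torch.rand(3, 5)                                  # float32
assert x.to(torch.float64).dtype == torch.float64

# type promotion (exercised by the new type_promotion tests)
a = torch.randint(10, (4,), dtype=torch.int32)
b = torch.randint(10, (4,))                           # int64 by default
assert torch.add(a, b, alpha=3).dtype == torch.int64                # same category: wider width wins
assert torch.add(b, torch.rand(4), alpha=3).dtype == torch.float32  # mixed category: float wins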


@ -33,6 +33,8 @@ from . import conv
from . import batchnorm
from . import quantized_models
from . import elementwise
from . import type_promotion
from . import type_conversion
from . import reduction
from . import argmax
from . import matmul


@ -0,0 +1,77 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class TypeConversionF32ToF64Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True)
    ])
    def forward(self, x):
        return x.to(torch.float64)


@register_test_case(module_factory=lambda: TypeConversionF32ToF64Module())
def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))


class TypeConversionF64ToF32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True)
    ])
    def forward(self, x):
        return x.to(torch.float32)


@register_test_case(module_factory=lambda: TypeConversionF64ToF32Module())
def TypeConversionF64ToF32Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5).type(torch.float64))


class TypeConversionI32ToI64Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True)
    ])
    def forward(self, x):
        return x.to(torch.int64)


@register_test_case(module_factory=lambda: TypeConversionI32ToI64Module())
def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
    module.forward(torch.randint(5, [2, 3]).type(torch.int32))


class TypeConversionI64ToI32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True)
    ])
    def forward(self, x):
        return x.to(torch.int32)


@register_test_case(module_factory=lambda: TypeConversionI64ToI32Module())
def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
    module.forward(torch.randint(5, [2, 3]))
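# Illustration only (not part of this patch): the `Tensor.to` semantics the
# conversion lowering is meant to match.
import torch

x = torch.tensor([3.7, -3.7])                  # float32
assert x.to(torch.float32) is x                # same-dtype conversion is a no-op
assert x.to(torch.int64).tolist() == [3, -3]   # cross-category cast truncates toward zero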


@ -0,0 +1,113 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=3)


@register_test_case(
    module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
    module.forward(
        torch.randint(10, [4]).type(torch.int32),
        torch.randint(10, [4]))


class TypePromotionDifferentCategoryModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=3)


@register_test_case(
    module_factory=lambda: TypePromotionDifferentCategoryModule())
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, [4]), torch.randn(4))


class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=2.3)


@register_test_case(
    module_factory=lambda: TypePromotionSameCategoryZeroRankWiderModule())
def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand().type(torch.float64))


class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=2)


@register_test_case(
    module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, [4]), tu.rand())


class TypePromotionAlphaWiderModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=2.3)


@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand())
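# Illustration only (not part of this patch): the zero-rank cases above are the
# subtle part of PyTorch's promotion rules -- a zero-rank operand of the same
# category does not widen the result, a zero-rank operand of a higher category
# does, and a wider Python-float `alpha` does not widen a float32 result either.
import torch

assert torch.result_type(torch.rand(4), torch.rand(()).to(torch.float64)) == torch.float32
assert torch.result_type(torch.randint(10, (4,)), torch.rand(())) == torch.float32
assert torch.add(torch.rand(4), torch.rand(()), alpha=2.3).dtype == torch.float32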


@ -1251,29 +1251,65 @@ public:
};
} // namespace
// Convert a scalar value to the target type. The scalar value can be an element
// of a tensor or a scalar in the Torch dialect. Both the scalar and the dtype
// should already be converted builtin types.
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
static Value promoteScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype) {
// TODO: For the integer case, we probably need the unconverted dtype to
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
// be able to know if we need signed or unsigned conversion.
if (dtype.isa<mlir::FloatType>()) {
if (scalar.getType().isa<mlir::FloatType>()) {
// `scalar` will always be f64 since that is what the TypeConverter
// converts !torch.float to.
return b.create<arith::TruncFOp>(loc, scalar, dtype);
} else {
assert(scalar.getType().isa<mlir::IntegerType>());
// `scalar` will always be i64 since that is what the TypeConverter
// converts !torch.int to.
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
return integerTy.getWidth() == 8;
}
return false;
};
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
scalarType.isSignlessInteger(1) || dtype.isSignlessInteger(1)) {
// TODO: Handle bool type.
mlir::emitError(loc)
<< "unsupported byte, char or bool type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
mlir::emitError(loc) << "promoteScalarToDtype for dtype " << dtype;
return nullptr;
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, scalar, dtype);
// Only scalarFloat width < dtypeFloat width can reach here.
return b.create<arith::ExtFOp>(loc, scalar, dtype);
}
assert(scalarType.isa<mlir::IntegerType>());
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::SIToFPOp>(loc, scalar, dtype);
}
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
return b.create<arith::FPToSIOp>(loc, scalar, dtype);
assert(scalarType.isa<mlir::IntegerType>());
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, scalar, dtype);
// Only scalarInteger width < dtypeInteger width can reach here.
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::ExtSIOp>(loc, scalar, dtype);
}
llvm_unreachable("convertScalarToDtype should handle all the types");
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
ArrayRef<Value> operands) {
OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
if (isa<AtenTanhOp>(op))
return b.create<math::TanhOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op))
@ -1322,40 +1358,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands);
if (add.alpha().getType().isa<Torch::FloatType>()) {
add.emitError("unimplemented: !torch.float 'alpha'");
return nullptr;
Type dtype = converter->convertType(add.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::AddIOp>(loc, lhs, scaled);
}
if (!add.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
add.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
adaptor.alpha());
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
return b.create<arith::AddFOp>(loc, payloadArgs[0], scaled);
}
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
AtenSubTensorOp::Adaptor adaptor(operands);
if (sub.alpha().getType().isa<Torch::FloatType>()) {
sub.emitError("unimplemented: !torch.float 'alpha'");
return nullptr;
Type dtype = converter->convertType(sub.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::SubIOp>(loc, lhs, scaled);
}
if (!sub.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
sub.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value alphaFloat = b.create<arith::SIToFPOp>(loc, payloadArgs[0].getType(),
adaptor.alpha());
Value scaled = b.create<arith::MulFOp>(loc, payloadArgs[1], alphaFloat);
return b.create<arith::SubFOp>(loc, payloadArgs[0], scaled);
}
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
if (!mul.getType()
@ -1386,7 +1417,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}
Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
Value expPromoted = promoteScalarToDtype(b, loc, operands[1], dtype);
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
}
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
@ -1430,7 +1461,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
}
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
auto dtype = clamp.getType().cast<ValueTensorType>().getDtype();
Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
clamp.emitError("unimplemented: non-floating point dtype");
return nullptr;
@ -1445,13 +1478,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) {
auto minPromoted = promoteScalarToDtype(b, loc, min, dtype);
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
result, minPromoted);
result = b.create<SelectOp>(loc, pred, minPromoted, result);
}
if (!max.getType().isa<Torch::NoneType>()) {
auto maxPromoted = promoteScalarToDtype(b, loc, max, dtype);
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
result, maxPromoted);
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
@ -1459,36 +1492,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result;
}
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
if (!rsub.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
Type dtype = converter->convertType(rsub.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
rsub.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value self = payloadArgs[0];
Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
}
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
Value input = payloadArgs[0];
Type inType = input.getType();
Type outType = atenToDtype.getType().cast<ValueTensorType>().getDtype();
Value result;
if (!inType.isF32()) {
atenToDtype.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
if (inType == outType)
result = input;
else if (outType.isInteger(64))
result = b.create<arith::FPToSIOp>(loc, b.getI64Type(), input);
else if (outType.isInteger(1))
result = b.create<arith::FPToSIOp>(loc, b.getI1Type(), input);
else
atenToDtype.emitError("unimplemented: unsupported target dtype");
Type dtype = converter->convertType(atenToDtype.getType())
.cast<RankedTensorType>()
.getElementType();
Value result = convertScalarToDtype(b, loc, input, dtype);
return result;
}
@ -1808,7 +1830,7 @@ struct ConvertElementwiseOp : ConversionPattern {
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForElementwiseOp(
b, loc, payloadArgs, op, operands);
b, loc, getTypeConverter(), payloadArgs, op, operands);
if (!result) {
hadErrorCreatingPayload = true;
return;
@ -2161,7 +2183,7 @@ public:
}
SmallVector<Value> expectedSize = getTypeConvertedValues(
rewriter, loc, typeConverter, expectedSizeTorchInt);
if (expectedSize.size() != resultRank) {
if (resultRank != (int64_t)expectedSize.size()) {
return rewriter.notifyMatchFailure(
op, "desired size list length mismatches with the result type rank");
}


@ -229,8 +229,8 @@ public:
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op>(op)) {
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenLog2Op>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@ -285,7 +285,7 @@ public:
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
AtenPowTensorScalarOp>(op)) {
AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
return visitBinaryTensorScalarOp(op, operands);
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,


@ -51,6 +51,8 @@ static bool isArgMemRefTypeValid(Type type) {
} else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) {
if (integerTy.isSignlessInteger(64))
return true;
if (integerTy.isSignlessInteger(32))
return true;
}
}
return false;
@ -109,7 +111,7 @@ static LogicalResult mungeFunction(
auto type = arg.getType();
if (!isArgMemRefTypeValid(type))
return emitError(arg.getLoc(),
"argument must be a memref of f32, f64, i64");
"argument must be a memref of f32, f64, i32, i64");
auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
arg.replaceAllUsesExcept(cast, cast);
arg.setType(getAbiTypeForMemRef(type));
@ -175,6 +177,8 @@ class MungeCallingConventions
};
// Memref return types.
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
"refbackend_consume_memref_int32_func_return");
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
"refbackend_consume_memref_int64_func_return");
createConsumeFuncReturnFunc(


@ -24,7 +24,7 @@ __all__ = [
def checkArgTypeIsSupported(ty):
SUPPORTED = [np.float32, np.float64, np.int64]
SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
class RefBackendInvoker:
@ -32,6 +32,10 @@ class RefBackendInvoker:
self.ee = ExecutionEngine(module)
self.result = None
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_i32_return(a):
self.result = unranked_memref_to_numpy(a, np.int32)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_i64_return(a):
self.result = unranked_memref_to_numpy(a, np.int64)
@ -56,6 +60,9 @@ class RefBackendInvoker:
def consume_f64_return(a):
self.result = a
self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
consume_memref_i32_return)
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
consume_memref_i64_return)


@ -20,9 +20,9 @@ from .framework import TestResult, TraceItem
class TensorSummary:
"""A summary of a tensor's contents."""
def __init__(self, tensor):
self.min = torch.min(tensor)
self.max = torch.max(tensor)
self.mean = torch.mean(tensor)
self.min = torch.min(tensor.type(torch.float64))
self.max = torch.max(tensor.type(torch.float64))
self.mean = torch.mean(tensor.type(torch.float64))
self.shape = list(tensor.shape)
def __str__(self):
@ -148,10 +148,15 @@ class ValueReport:
if isinstance(golden, torch.Tensor):
if not isinstance(value, torch.Tensor):
return self._record_mismatch_type_failure('torch.Tensor', value)
if value.shape != golden.shape:
return self._record_failure(
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
)
if value.dtype != golden.dtype:
return self._record_failure(
f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
)
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
return self._record_failure(
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
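# Illustration only (not part of this patch): the new dtype check catches results
# that are numerically right but carry the wrong element type, which the value
# comparison alone is not designed to flag.
import torch

golden = torch.arange(4)                 # int64
value = golden.to(torch.int32)           # same numbers, wrong dtype
assert torch.equal(value.to(golden.dtype), golden)
assert value.dtype != golden.dtype       # this is what the report now calls out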


@ -980,12 +980,11 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
// ----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
@ -994,27 +993,25 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
// ----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.to.dtype
// CHECK-SAME: (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK-LABEL: func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.tensor<[?,?],si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
%none = torch.constant.none
%false = torch.constant.bool false
@ -1025,12 +1022,10 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
// ----
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
return %0: !torch.tensor