Elide `!torch.` prefix in nested dialect types.

This leads to much more succinct types in many cases:

```
!torch.list<!torch.int>
!torch.list<int>

!torch.tuple<!torch.list<!torch.int>, !torch.list<!torch.int>>
!torch.tuple<list<int>, list<int>>

!torch.optional<!torch.list<!torch.int>>
!torch.optional<list<int>>

!torch.list<list<list<tensor>>>
!torch.list<!torch.list<!torch.list<!torch.tensor>>>
```

I would like to take this further and allow omitting the `!torch.`
prefix in all cases, but that's harder -- for example, we currently use
`FuncOp` for functions, and so I don't think we can customize the
printing there. It seems like it will be a longer road to getting that
level of customization.
pull/669/head
Sean Silva 2022-03-15 23:22:56 +00:00
parent 3734f69119
commit 84a9693006
36 changed files with 1567 additions and 1477 deletions

View File

@ -36,6 +36,14 @@ def Torch_Dialect : Dialect {
let hasRegionArgAttrVerify = 1;
let hasConstantMaterializer = 1;
let useDefaultTypePrinterParser = 0;
let extraClassDeclaration = [{
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &printer) const override;
}];
}
class TorchOpTrait<string name> : NativeOpTrait<""> {

View File

@ -14,4 +14,18 @@
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc"
namespace mlir {
namespace torch {
namespace Torch {
/// Parse a type registered to this dialect.
Type parseTorchDialectType(AsmParser &parser);
/// Print a type registered to this dialect.
void printTorchDialectType(Type type, AsmPrinter &printer);
} // namespace Torch
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H

View File

@ -1060,7 +1060,7 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
(in the region `shapeCalculation`) which calculates the shapes for
the set of values yielded by the `body` region.
The `shapeCalculation` region yields a `!torch.list<!torch.int>` for each
The `shapeCalculation` region yields a `!torch.list<int>` for each
value yielded by the `body` region.
Conceptually, the `shapeCalculation` region executes first, then `body`

View File

@ -26,14 +26,21 @@ class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type
let parameters = (ins "::mlir::Type":$containedType);
let printer = [{
$_printer << "<" << getImpl()->containedType << ">";
$_printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, $_printer);
$_printer << ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
Type containedType;
if ($_parser.parseType(containedType))
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType($_parser);
if (!containedType)
return Type();
if ($_parser.parseGreater())
return Type();
@ -344,19 +351,23 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
}];
let printer = [{
$_printer << "<" << getImpl()->keyType << ", " << getImpl()->valueType << ">";
$_printer << "<";
printTorchDialectType(getImpl()->keyType, $_printer);
$_printer << ", ";
printTorchDialectType(getImpl()->valueType, $_printer);
$_printer<< ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
Type keyType;
if ($_parser.parseType(keyType))
Type keyType = parseTorchDialectType($_parser);
if (!keyType)
return Type();
if ($_parser.parseComma())
return Type();
Type valueType;
if ($_parser.parseType(valueType))
Type valueType = parseTorchDialectType($_parser);
if (!valueType)
return Type();
if ($_parser.parseGreater())
return Type();

View File

@ -47,6 +47,58 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
#define GET_TYPEDEF_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
//===----------------------------------------------------------------------===//
// Top-level parsing/printing of types for TorchDialect.
//===----------------------------------------------------------------------===//
//
// Unfortunately, TorchDialect::parseType/printType are non-static member
// functions, even though they don't depend on any instance state of the
// dialect. This is problematic, for example, when wanting to call these
// functions directly from type printers/parsers.
//
// So define some helpers that are free functions.
/// Parse a type registered to this dialect.
Type Torch::parseTorchDialectType(AsmParser &parser) {
SMLoc typeLoc = parser.getCurrentLocation();
StringRef mnemonic;
if (parser.parseOptionalKeyword(&mnemonic)) {
parser.emitError(parser.getCurrentLocation())
.append("expected type mnemonic")
.attachNote()
.append("for types like `!torch.list<int>`, you must omit the "
"`!torch.` prefix for the nested types");
return Type();
}
Type genType;
auto parseResult = generatedTypeParser(parser, mnemonic, genType);
if (parseResult.hasValue())
return genType;
parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `"
<< TorchDialect::getDialectNamespace() << "`";
return {};
}
/// Print a type registered to this dialect.
void Torch::printTorchDialectType(Type type, AsmPrinter &printer) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
}
//===----------------------------------------------------------------------===//
// Torch dialect parseType/printType methods.
//===----------------------------------------------------------------------===//
/// Parse a type registered to this dialect.
Type TorchDialect::parseType(DialectAsmParser &parser) const {
return parseTorchDialectType(parser);
}
/// Print a type registered to this dialect.
void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
printTorchDialectType(type, printer);
}
//===----------------------------------------------------------------------===//
// Dialect initialize method.
//===----------------------------------------------------------------------===//

View File

@ -30,8 +30,8 @@ Type Torch::TupleType::parse(AsmParser &parser) {
SmallVector<Type> containedTypes;
do {
Type containedType;
if (parser.parseType(containedType))
Type containedType = parseTorchDialectType(parser);
if (!containedType)
return Type();
containedTypes.push_back(containedType);
} while (!parser.parseOptionalComma());
@ -42,7 +42,9 @@ Type Torch::TupleType::parse(AsmParser &parser) {
void Torch::TupleType::print(::mlir::AsmPrinter &printer) const {
printer << "<";
llvm::interleaveComma(getContainedTypes(), printer);
llvm::interleaveComma(getContainedTypes(), printer, [&](Type type) {
printTorchDialectType(type, printer);
});
printer << ">";
}

View File

@ -99,8 +99,8 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
// For the non-None case, we need to unwrap the optional type and then adjust
// it recursively (which also takes care of derefining it to ultimate desired
// type).
// A case where this happens is `!torch.optional<!torch.vtensor>` ->
// `!torch.optional<!torch.list<!torch.int>>>`.
// A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
// if optional is None:
@ -131,7 +131,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<!torch.list<!torch.int>>>`.
// `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
auto adjusted = adjustShapeFunctionArg(
operand, desiredOptionalType.getContainedType(), b, loc);
@ -139,7 +139,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
}
// The shape library functions have tensor operands replaced with
// `!torch.list<!torch.int>` types for the shape. Get the sizes.
// `!torch.list<int>` types for the shape. Get the sizes.
if (operand.getType().isa<Torch::BaseTensorType>()) {
assert(desiredType.isa<Torch::ListType>() &&
"Don't expect shape functions to have tensor parameters");
@ -147,7 +147,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
}
// Run this after `operand.getType().isa<Torch::BaseTensorType>()` so that
// `!torch.vtensor` -> `!torch.list<!torch.int>` is handled there specially
// `!torch.vtensor` -> `!torch.list<int>` is handled there specially
// first.
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
return adjustListArg(operand, desiredListType, b, loc);

File diff suppressed because it is too large Load Diff

View File

@ -18,10 +18,10 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%dilation = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}

View File

@ -208,7 +208,7 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
@ -219,10 +219,10 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%dim0 = torch.constant.int 0
%reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<!torch.int>
%reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%keepdims = torch.constant.bool false
%dtype = torch.constant.none
%0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
%0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
@ -234,7 +234,7 @@ func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3:.*]] = torch.constant.int 0
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
@ -243,8 +243,8 @@ func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
%none = torch.constant.none
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
%0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
return %1 : !torch.vtensor<[?,?,?],f32>
}
@ -478,15 +478,15 @@ func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-1]} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
// CHECK: }
func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
%dim0 = torch.constant.int -1
%shape = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<!torch.int>
%0 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[?],f32>
%shape = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%0 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
@ -557,7 +557,7 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
// CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
// CHECK: %[[VAL_8:.*]] = torch.constant.int 2
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() {value = dense<1.200000e+01> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
@ -588,8 +588,8 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<!torch.int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
%0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
return %result0 : !torch.vtensor<[5,2,2,3],f32>
}
@ -618,7 +618,7 @@ func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
@ -628,8 +628,8 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int0, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
%0 = torch.prim.ListConstruct %int0, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,2,4],f32>
return %1 : !torch.vtensor<[3,2,4],f32>
}
@ -672,7 +672,7 @@ func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
@ -682,8 +682,8 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
return %1 : !torch.vtensor<[3,4],f32>
}
@ -693,7 +693,7 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<3x4xi32>} : () -> tensor<3x4xi32>
// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
@ -703,7 +703,7 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
return %1 : !torch.vtensor<[3,4],f32>
}

View File

@ -2,20 +2,20 @@
// CHECK that multiple nested initialization ops are properly handled.
// CHECK-LABEL: torch.global_slot @l : !torch.list<!torch.list<!torch.list<!torch.tensor>>> {
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<tensor>>> {
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<tensor>
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<tensor>, !torch.list<tensor>) -> !torch.list<list<tensor>>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<list<tensor>>, !torch.list<list<tensor>>) -> !torch.list<list<list<tensor>>>
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<list<list<tensor>>>
// CHECK: }
torch.class_type @c {
torch.attr "l" : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
torch.attr "l" : !torch.list<list<list<tensor>>>
}
%l0 = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
%l1 = torch.prim.ListConstruct %l0, %l0 : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
%l2 = torch.prim.ListConstruct %l1, %l1 : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
%l0 = torch.prim.ListConstruct : () -> !torch.list<tensor>
%l1 = torch.prim.ListConstruct %l0, %l0 : (!torch.list<tensor>, !torch.list<tensor>) -> !torch.list<list<tensor>>
%l2 = torch.prim.ListConstruct %l1, %l1 : (!torch.list<list<tensor>>, !torch.list<list<tensor>>) -> !torch.list<list<list<tensor>>>
torch.nn_module {
torch.slot "l", %l2 : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
torch.slot "l", %l2 : !torch.list<list<list<tensor>>>
} : !torch.nn.Module<"c">

View File

@ -6,7 +6,7 @@ torch.class_type @parent {
func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
// expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<!torch.nn.Module<"parent">>
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<nn.Module<"parent">>
return
}

View File

@ -56,18 +56,18 @@ func @none_call_return() {
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
return %1 : !torch.tuple<tensor, tensor>
}
// CHECK-LABEL: func @call_tuple_return(
@ -84,16 +84,16 @@ func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f
// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor>
}

View File

@ -24,17 +24,17 @@ func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !
// CHECK-LABEL: func @torch.aten.__is__
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.__is__(%arg0: !torch.list<!torch.int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.aten.__is__ %arg0, %arg1 : !torch.list<!torch.int>, !torch.none -> !torch.bool
func @torch.aten.__is__(%arg0: !torch.list<int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.aten.__is__ %arg0, %arg1 : !torch.list<int>, !torch.none -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.__is__$derefine_is_none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func @torch.aten.__is__$derefine_is_none(%arg0: !torch.list<!torch.int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.derefine %arg0 : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
%1 = torch.aten.__is__ %0, %arg1 : !torch.optional<!torch.list<!torch.int>>, !torch.none -> !torch.bool
func @torch.aten.__is__$derefine_is_none(%arg0: !torch.list<int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.derefine %arg0 : !torch.list<int> to !torch.optional<list<int>>
%1 = torch.aten.__is__ %0, %arg1 : !torch.optional<list<int>>, !torch.none -> !torch.bool
return %1 : !torch.bool
}
@ -52,16 +52,16 @@ func @torch.aten.__is__$none_is_none(%arg0: !torch.none, %arg1: !torch.none) ->
// CHECK: return %[[RESULT]] : !torch.bool
func @torch.aten.__is__$is_none$derefine(%arg0: !torch.vtensor) -> !torch.bool {
%none = torch.constant.none
%0 = torch.derefine %arg0 : !torch.vtensor to !torch.optional<!torch.vtensor>
%1 = torch.aten.__is__ %0, %none : !torch.optional<!torch.vtensor>, !torch.none -> !torch.bool
%0 = torch.derefine %arg0 : !torch.vtensor to !torch.optional<vtensor>
%1 = torch.aten.__is__ %0, %none : !torch.optional<vtensor>, !torch.none -> !torch.bool
return %1 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.__isnot__
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.__isnot__(%arg0: !torch.list<!torch.int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.aten.__isnot__ %arg0, %arg1 : !torch.list<!torch.int>, !torch.none -> !torch.bool
func @torch.aten.__isnot__(%arg0: !torch.list<int>, %arg1: !torch.none) -> !torch.bool {
%0 = torch.aten.__isnot__ %arg0, %arg1 : !torch.list<int>, !torch.none -> !torch.bool
return %0 : !torch.bool
}
@ -104,25 +104,25 @@ func @torch.aten.ne.bool$different_operand(%a: !torch.bool) -> !torch.bool {
}
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<!torch.int> {
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: return %[[LIST]] : !torch.list<!torch.int>
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<!torch.int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !torch.list<!torch.int>
return %0 : !torch.list<!torch.int>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[LIST]] : !torch.list<int>
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !torch.list<int>
return %0 : !torch.list<int>
}
// One size unknown, so cannot canonicalize.
// TODO: For unknown sizes, insert the equivalent of a "dim" op.
// Then this will only require static rank.
// CHECK-LABEL: func @torch.aten.size$unknown_size(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
return %0 : !torch.list<!torch.int>
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !torch.list<int>
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int>
return %0 : !torch.list<int>
}
// CHECK-LABEL: func @torch.aten.ne.int$same_operand(
@ -469,8 +469,8 @@ func @torch.prim.min.self_int$basic() -> !torch.int {
%int-1 = torch.constant.int -1
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int-1, %int0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.min.self_int %0 : !torch.list<!torch.int> -> !torch.int
%0 = torch.prim.ListConstruct %int-1, %int0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.min.self_int %0 : !torch.list<int> -> !torch.int
return %1 : !torch.int
}
@ -479,8 +479,8 @@ func @torch.prim.min.self_int$basic() -> !torch.int {
func @torch.prim.min.self_int$nofold$dynamic(%arg0: !torch.int) -> !torch.int {
%int-1 = torch.constant.int -1
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int-1, %int0, %arg0: (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.min.self_int %0 : !torch.list<!torch.int> -> !torch.int
%0 = torch.prim.ListConstruct %int-1, %int0, %arg0: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.min.self_int %0 : !torch.list<int> -> !torch.int
return %1 : !torch.int
}
@ -489,8 +489,8 @@ func @torch.prim.min.self_int$nofold$dynamic(%arg0: !torch.int) -> !torch.int {
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> !torch.int
// CHECK: return %[[DIM]] : !torch.int
func @torch.aten.len.t$of_size(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !torch.list<!torch.int>
%1 = torch.aten.len.t %0 : !torch.list<!torch.int> -> !torch.int
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !torch.list<int>
%1 = torch.aten.len.t %0 : !torch.list<int> -> !torch.int
return %1 : !torch.int
}
@ -508,18 +508,18 @@ func @torch.aten.dim$with_shape(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.in
// CHECK: %[[LEN:.*]] = torch.constant.int 4
// CHECK: return %[[LEN]] : !torch.int
func @torch.aten.len.t$of_build_list(%arg0: !torch.int) -> !torch.int {
%0 = torch.prim.ListConstruct %arg0, %arg0, %arg0, %arg0 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.len.t %0 : !torch.list<!torch.int> -> !torch.int
%0 = torch.prim.ListConstruct %arg0, %arg0, %arg0, %arg0 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.len.t %0 : !torch.list<int> -> !torch.int
return %1 : !torch.int
}
// CHECK-LABEL: func @torch.aten.len.t$no_fold_list_mutated()
func @torch.aten.len.t$no_fold_list_mutated() -> !torch.int {
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%1 = torch.aten.append.t %0, %int4 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.aten.append.t %0, %int4 : !torch.list<int>, !torch.int -> !torch.list<int>
// CHECK: torch.aten.len.t
%2 = torch.aten.len.t %0 : !torch.list<!torch.int> -> !torch.int
%2 = torch.aten.len.t %0 : !torch.list<int> -> !torch.int
return %2 : !torch.int
}
@ -530,8 +530,8 @@ func @torch.aten.__getitem__.t() -> !torch.int {
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.__getitem__.t %0, %int1 : !torch.list<!torch.int>, !torch.int -> !torch.int
%0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.__getitem__.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.int
return %1 : !torch.int
}
@ -539,25 +539,25 @@ func @torch.aten.__getitem__.t() -> !torch.int {
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test0(
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %[[LIST]], %arg0 : !torch.list<!torch.int>, !torch.int -> !torch.int
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %[[LIST]], %arg0 : !torch.list<int>, !torch.int -> !torch.int
// CHECK: return %[[ITEM]] : !torch.int
func @torch.aten.__getitem__.t$no_change_test0(%arg0: !torch.int) -> !torch.int {
%int5 = torch.constant.int 5
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.__getitem__.t %0, %arg0 : !torch.list<!torch.int>, !torch.int -> !torch.int
%0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.__getitem__.t %0, %arg0 : !torch.list<int>, !torch.int -> !torch.int
return %1 : !torch.int
}
// Not canonicalized because of passed in list
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test1(
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %arg0, %[[C5]] : !torch.list<!torch.int>, !torch.int -> !torch.int
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %arg0, %[[C5]] : !torch.list<int>, !torch.int -> !torch.int
// CHECK: return %[[ITEM]] : !torch.int
func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<!torch.int>) -> !torch.int {
func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<int>) -> !torch.int {
%int5 = torch.constant.int 5
%0 = torch.aten.__getitem__.t %arg0, %int5 : !torch.list<!torch.int>, !torch.int -> !torch.int
%0 = torch.aten.__getitem__.t %arg0, %int5 : !torch.list<int>, !torch.int -> !torch.int
return %0 : !torch.int
}
@ -567,8 +567,8 @@ func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<!torch.int>) -
// CHECK: %[[RESULT:.*]] = torch.aten.size.int %[[TENSOR]], %[[INDEX]] : !torch.tensor, !torch.int -> !torch.int
// CHECK: return %[[RESULT]] : !torch.int
func @torch.aten.__getitem__.t$getitem_of_size(%arg0: !torch.tensor, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.size %arg0 : !torch.tensor -> !torch.list<!torch.int>
%1 = torch.aten.__getitem__.t %0, %arg1 : !torch.list<!torch.int>, !torch.int -> !torch.int
%0 = torch.aten.size %arg0 : !torch.tensor -> !torch.list<int>
%1 = torch.aten.__getitem__.t %0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
return %1 : !torch.int
}
@ -579,8 +579,8 @@ func @torch.aten.__getitem__.t$negative_index() -> !torch.int {
%int7 = torch.constant.int 7
%int8 = torch.constant.int 8
%int-1 = torch.constant.int -1
%0 = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.__getitem__.t %0, %int-1 : !torch.list<!torch.int>, !torch.int -> !torch.int
%0 = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.__getitem__.t %0, %int-1 : !torch.list<int>, !torch.int -> !torch.int
return %1 : !torch.int
}
@ -589,9 +589,9 @@ func @torch.aten.__getitem__.t$invalid_index() -> !torch.int {
%int7 = torch.constant.int 7
%int8 = torch.constant.int 8
%int-1 = torch.constant.int -100
%0 = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%0 = torch.prim.ListConstruct %int7, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.__getitem__.t
%1 = torch.aten.__getitem__.t %0, %int-1 : !torch.list<!torch.int>, !torch.int -> !torch.int
%1 = torch.aten.__getitem__.t %0, %int-1 : !torch.list<int>, !torch.int -> !torch.int
return %1 : !torch.int
}
@ -599,9 +599,9 @@ func @torch.aten.__getitem__.t$invalid_index() -> !torch.int {
// CHECK: %[[RET:.*]] = torch.constant.bool false
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.eq.int_list$fold$literals_of_different_sizes(%arg0: !torch.int) -> !torch.bool {
%0 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<!torch.int>
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<!torch.int>, !torch.list<!torch.int> -> !torch.bool
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
%1 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.bool
return %2 : !torch.bool
}
@ -609,18 +609,18 @@ func @torch.aten.eq.int_list$fold$literals_of_different_sizes(%arg0: !torch.int)
// CHECK: %[[RET:.*]] = torch.constant.bool true
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.eq.int_list$fold$same_literal(%arg0: !torch.int) -> !torch.bool {
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<!torch.int>
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<!torch.int>, !torch.list<!torch.int> -> !torch.bool
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.int_list$no_fold$different_literals(
func @torch.aten.eq.int_list$no_fold$different_literals(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<!torch.int>
%0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
// CHECK: torch.aten.eq.int_list
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<!torch.int>, !torch.list<!torch.int> -> !torch.bool
%2 = torch.aten.eq.int_list %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.bool
return %2 : !torch.bool
}
@ -746,8 +746,8 @@ func @torch.prim.If$fold_same_result$subset_of_results(%arg0: !torch.bool, %arg1
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[ARG0]] : !torch.tensor
func @torch.prim.TupleUnpack(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor{
%123 = torch.prim.TupleConstruct %arg0, %arg1: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
%124:2 = torch.prim.TupleUnpack %123 : !torch.tuple<!torch.tensor, !torch.tensor> -> !torch.tensor, !torch.tensor
%123 = torch.prim.TupleConstruct %arg0, %arg1: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
%124:2 = torch.prim.TupleUnpack %123 : !torch.tuple<tensor, tensor> -> !torch.tensor, !torch.tensor
return %124#0 : !torch.tensor
}
@ -760,11 +760,11 @@ func @torch.prim.TupleUnpack(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !tor
// CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
// CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
// CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor)
// CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
// CHECK-SAME: -> !torch.dict<str, tensor>
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.__contains__.str(%k0 : !torch.str, %v0: !torch.tensor, %k1: !torch.str, %v1: !torch.tensor) -> !torch.bool{
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<!torch.str, !torch.tensor>
%pred = torch.aten.__contains__.str %dict, %k0 : !torch.dict<!torch.str, !torch.tensor>, !torch.str -> !torch.bool
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<str, tensor>
%pred = torch.aten.__contains__.str %dict, %k0 : !torch.dict<str, tensor>, !torch.str -> !torch.bool
return %pred : !torch.bool
}
@ -774,17 +774,17 @@ func @torch.aten.__contains__.str(%k0 : !torch.str, %v0: !torch.tensor, %k1: !to
// CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
// CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
// CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor)
// CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
// CHECK-SAME: -> !torch.dict<str, tensor>
// CHECK: torch.aten._set_item.str %[[DICT]], %[[K0]], %[[V1]] :
// CHECK-SAME: !torch.dict<!torch.str, !torch.tensor>, !torch.str, !torch.tensor
// CHECK-SAME: !torch.dict<str, tensor>, !torch.str, !torch.tensor
// CHECK: %[[RET:.*]] = torch.aten.__contains__.str %[[DICT]], %[[K0]] :
// CHECK-SAME: !torch.dict<!torch.str, !torch.tensor>, !torch.str -> !torch.bool
// CHECK-SAME: !torch.dict<str, tensor>, !torch.str -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.__contains__.str$with_dict_modified(%k0 : !torch.str, %v0: !torch.tensor, %k1: !torch.str, %v1: !torch.tensor) -> !torch.bool{
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<!torch.str, !torch.tensor>
torch.aten._set_item.str %dict, %k0, %v1 : !torch.dict<!torch.str, !torch.tensor>, !torch.str, !torch.tensor
%pred = torch.aten.__contains__.str %dict, %k0 : !torch.dict<!torch.str, !torch.tensor>, !torch.str -> !torch.bool
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<str, tensor>
torch.aten._set_item.str %dict, %k0, %v1 : !torch.dict<str, tensor>, !torch.str, !torch.tensor
%pred = torch.aten.__contains__.str %dict, %k0 : !torch.dict<str, tensor>, !torch.str -> !torch.bool
return %pred : !torch.bool
}
@ -794,11 +794,11 @@ func @torch.aten.__contains__.str$with_dict_modified(%k0 : !torch.str, %v0: !tor
// CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
// CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
// CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor)
// CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
// CHECK-SAME: -> !torch.dict<str, tensor>
// CHECK: return %[[V0]] : !torch.tensor
func @torch.aten.__getitem__.Dict_str(%k0 : !torch.str, %v0: !torch.tensor, %k1: !torch.str, %v1: !torch.tensor) -> !torch.tensor {
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<!torch.str, !torch.tensor>
%v = torch.aten.__getitem__.Dict_str %dict, %k0 : !torch.dict<!torch.str, !torch.tensor>, !torch.str -> !torch.tensor
%dict = torch.prim.DictConstruct keys(%k0, %k1: !torch.str, !torch.str) values(%v0, %v1: !torch.tensor, !torch.tensor) -> !torch.dict<str, tensor>
%v = torch.aten.__getitem__.Dict_str %dict, %k0 : !torch.dict<str, tensor>, !torch.str -> !torch.tensor
return %v : !torch.tensor
}
@ -924,18 +924,18 @@ func @torch.aten.size.int$invalid_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.in
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
// CHECK: return %[[ARG]] : !torch.int
func @torch.prim.unchecked_cast$derefine_identity(%arg0: !torch.int) -> !torch.int {
%0 = torch.derefine %arg0 : !torch.int to !torch.optional<!torch.int>
%1 = torch.prim.unchecked_cast %0 : !torch.optional<!torch.int> -> !torch.int
%0 = torch.derefine %arg0 : !torch.int to !torch.optional<int>
%1 = torch.prim.unchecked_cast %0 : !torch.optional<int> -> !torch.int
return %1 : !torch.int
}
// CHECK-LABEL: func @torch.derefine$of_unchecked_cast(
// CHECK-SAME: %[[ARG:.*]]: !torch.optional<!torch.int>) -> !torch.optional<!torch.int> {
// CHECK: return %[[ARG]] : !torch.optional<!torch.int>
func @torch.derefine$of_unchecked_cast(%arg0: !torch.optional<!torch.int>) -> !torch.optional<!torch.int> {
%0 = torch.prim.unchecked_cast %arg0 : !torch.optional<!torch.int> -> !torch.int
%1 = torch.derefine %0 : !torch.int to !torch.optional<!torch.int>
return %1 : !torch.optional<!torch.int>
// CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.optional<int> {
// CHECK: return %[[ARG]] : !torch.optional<int>
func @torch.derefine$of_unchecked_cast(%arg0: !torch.optional<int>) -> !torch.optional<int> {
%0 = torch.prim.unchecked_cast %arg0 : !torch.optional<int> -> !torch.int
%1 = torch.derefine %0 : !torch.int to !torch.optional<int>
return %1 : !torch.optional<int>
}
// CHECK-LABEL: func @torch.tensor_static_info_cast$downcast_first(
@ -981,9 +981,9 @@ func @torch.tensor_static_info_cast$no_refine(%arg0: !torch.vtensor) -> !torch.v
// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[T1]] : !torch.tensor
func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor, tensor>
%int1 = torch.constant.int 1
%1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
%1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<tensor, tensor, tensor>, !torch.int -> !torch.tensor
return %1 : !torch.tensor
}
@ -992,23 +992,23 @@ func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.
// CHECK: %[[INDEX3:.*]] = torch.constant.int 3
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
// CHECK-SAME: !torch.tensor, !torch.tensor, !torch.tensor ->
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
// CHECK-SAME: !torch.tuple<tensor, tensor, tensor>
// CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK-SAME: !torch.tuple<tensor, tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor, tensor>
%int3 = torch.constant.int 3
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<tensor, tensor, tensor>, !torch.int -> !torch.tensor
return %1 : !torch.tensor
}
// CHECK-LABEL: func @torch.prim.unchecked_cast$derefine
// CHECK-next: return %arg0 : !torch.list<!torch.int>
func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<!torch.int>) -> !torch.list<!torch.int> {
%0 = torch.derefine %arg0 : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
%1 = torch.prim.unchecked_cast %0 : !torch.optional<!torch.list<!torch.int>> -> !torch.list<!torch.int>
return %1 : !torch.list<!torch.int>
// CHECK-next: return %arg0 : !torch.list<int>
func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = torch.derefine %arg0 : !torch.list<int> to !torch.optional<list<int>>
%1 = torch.prim.unchecked_cast %0 : !torch.optional<list<int>> -> !torch.list<int>
return %1 : !torch.list<int>
}
// CHECK-LABEL: func @torch.aten.Int.Tensor(
@ -1076,7 +1076,7 @@ func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch.tens
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?],f32>
func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
%int-1 = torch.constant.int -1
%0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<!torch.int> -> !torch.tensor<[?],f32>
%0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<int> -> !torch.tensor<[?],f32>
return %1 : !torch.tensor<[?],f32>
}

View File

@ -37,11 +37,11 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIM_LIST]], %[[KEEP_DIM]], %[[SUM_DTYPE]] :
// CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
// CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
@ -64,11 +64,11 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIM_LIST]], %[[KEEP_DIM]], %[[SUM_DTYPE]] :
// CHECK-SAME !torch.tensor<[2,3],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.tensor<[2,1],f32>
// CHECK-SAME !torch.tensor<[2,3],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor<[2,1],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
@ -91,11 +91,11 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>,
// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.float -> !torch.tensor<[?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIM_LIST]], %[[KEEP_DIM]], %[[SUM_DTYPE]] :
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.tensor<[?,1],f32>
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor<[?,1],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[?,?],f32>, !torch.tensor<[?,1],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[?,?],f32> to !torch.tensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.tensor<[?,?],f32>
@ -118,11 +118,11 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<*,f32>, !torch.tensor<*,f32>,
// CHECK-SAME: !torch.float -> !torch.tensor<*,f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIM_LIST]], %[[KEEP_DIM]], %[[SUM_DTYPE]] :
// CHECK-SAME: !torch.tensor<*,f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.tensor<*,f32>
// CHECK-SAME: !torch.tensor<*,f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<*,f32>, !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor<*,f32>
// CHECK: return %[[RET]] : !torch.tensor<*,f32>
@ -135,16 +135,16 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
// -----
// CHECK-LABEL: func @torch.aten.size(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: return %[[SIZE]] : !torch.list<!torch.int>
func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
return %0 : !torch.list<!torch.int>
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[SIZE]] : !torch.list<int>
func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----
@ -327,8 +327,8 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
%c2 = torch.constant.int 2
%c256 = torch.constant.int 256
%c32 = torch.constant.int 32
%0 = torch.prim.ListConstruct %c1, %c2, %c256, %c32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[1,512,32],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,2,256,32],f32>
%0 = torch.prim.ListConstruct %c1, %c2, %c256, %c32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[1,512,32],f32>, !torch.list<int> -> !torch.vtensor<[1,2,256,32],f32>
return %1 : !torch.vtensor<[1,2,256,32],f32>
}
@ -342,8 +342,8 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
%c256 = torch.constant.int 512
%c32 = torch.constant.int 32
%0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32>
%0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[512,32],f32>
return %1 : !torch.vtensor<[512,32],f32>
}
@ -358,11 +358,11 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
// CHECK: %[[LOG:.*]] = torch.aten.log %[[SUM_DIM]] : !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[1,?,?],f32>
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>,
@ -492,16 +492,16 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INPUT]], %[[CST2]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[ADD]], %[[CST6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST1_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[CST0_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
@ -520,10 +520,10 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INP]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RELU:.*]] = torch.aten.relu %[[ADD]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[INT6_:.*]] = torch.constant.int 6
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@ -539,16 +539,16 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float,
// CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> {
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MIN_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[MAX_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
@ -563,16 +563,16 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.new_zeros %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.new_zeros %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
return %1 : !torch.vtensor<[2,3],f32>
}
@ -582,16 +582,16 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
// CHECK: }
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
%none = torch.constant.none
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.new_ones %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],si64>, !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.new_ones %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],si64>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
return %1 : !torch.vtensor<[3,4],si64>
}
@ -613,9 +613,9 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
@ -623,8 +623,8 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.full %0, %float5.000000e00, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.full %0, %float5.000000e00, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
return %1 : !torch.vtensor<[2,3],f32>
}
@ -637,8 +637,8 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INP]], %[[INT0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

View File

@ -13,8 +13,8 @@ func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
torch.shape.calculate.yield %2 : !torch.vtensor<[2,?],unk>
} shapes {
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,?],unk>, !torch.int -> !torch.int
%3 = torch.prim.ListConstruct %int2, %2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %3 : !torch.list<!torch.int>
%3 = torch.prim.ListConstruct %int2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %3 : !torch.list<int>
} : !torch.vtensor<[2,?],unk>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[2,?],unk> to !torch.vtensor
return %1 : !torch.vtensor

View File

@ -115,18 +115,18 @@ builtin.func @f(%arg0: i32 {torch.type_bound = i32})
// -----
builtin.func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<!torch.tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<!torch.tensor> to !torch.tensor
builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor
return %0 : !torch.tensor
}
// -----
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<!torch.tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<!torch.tensor>
return %0 : !torch.optional<!torch.tensor>
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor>
return %0 : !torch.optional<tensor>
}
// -----
@ -166,7 +166,7 @@ builtin.func @torch.tensor() {
builtin.func @torch.prim.ListConstruct() {
%int2 = torch.constant.int 2
// expected-error@+1 {{operand types should have the same type as the list contained type}}
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor>
return
}

View File

@ -46,25 +46,25 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
}
// CHECK-LABEL: func @mutation_followed_by_view_like_ops(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
return %value_result : !torch.vtensor
}
// CHECK-LABEL: func @mutation_of_view_like_op_result(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: return %[[OVERWRITER]] : !torch.vtensor
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor
%result = torch.copy.to_vtensor %view : !torch.vtensor
return %result : !torch.vtensor

View File

@ -30,10 +30,13 @@ func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: @tuple.empty() -> !torch.tuple<>
func private @tuple.empty() -> !torch.tuple<>
// CHECK: @tuple.one_element() -> !torch.tuple<!torch.tensor>
func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK: @tuple.one_element() -> !torch.tuple<tensor>
func private @tuple.one_element() -> !torch.tuple<tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor>
func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>
// CHECK: @dict() -> !torch.dict<str, tensor>
func private @dict() -> !torch.dict<str, tensor>
// CHECK-LABEL: func @torch.tensor.literal() {
func @torch.tensor.literal() {
@ -51,9 +54,9 @@ func @torch.vtensor.literal() {
return
}
func @derefine(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
%0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<!torch.tensor>
return %0 : !torch.optional<!torch.tensor>
func @derefine(%arg0: !torch.tensor) -> !torch.optional<tensor> {
%0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<tensor>
return %0 : !torch.optional<tensor>
}
func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
@ -106,7 +109,7 @@ torch.class_type @test {
torch.attr "f" : !torch.float
torch.attr "t" : !torch.tensor
torch.attr "submodule" : !torch.nn.Module<"empty">
torch.attr "ob" : !torch.optional<!torch.bool>
torch.attr "ob" : !torch.optional<bool>
torch.attr "s" : !torch.str
torch.method "method", @f
}
@ -126,8 +129,8 @@ func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
torch.shape.calculate.yield %0 : !torch.vtensor
} shapes {
%0 = torch.aten.size %arg0 : !torch.vtensor -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %0 : !torch.list<!torch.int>
%0 = torch.aten.size %arg0 : !torch.vtensor -> !torch.list<int>
torch.shape.calculate.yield.shapes %0 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}

View File

@ -2,16 +2,16 @@
// -----
func @convert_to_value_semantic_tensors_list( %list: !torch.list<!torch.tensor>) -> !torch.tensor {
func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !torch.tensor {
%int1 = torch.constant.int 1
// expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}}
%ret = torch.aten.cat %list, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
%ret = torch.aten.cat %list, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<!torch.tensor>,
func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<tensor>,
%t: !torch.tensor,
%training: !torch.bool,
%cudnn_enable: !torch.bool,
@ -19,8 +19,8 @@ func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.option
// expected-error@+1 {{failed to legalize operation 'torch.aten.batch_norm' that was explicitly marked illegal}}
%ret = torch.aten.batch_norm %t, %tensor_optional, %tensor_optional, %tensor_optional,
%tensor_optional, %training, %f, %f, %cudnn_enable:
!torch.tensor, !torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.tensor, !torch.optional<tensor>, !torch.optional<tensor>,
!torch.optional<tensor>, !torch.optional<tensor>,
!torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}

View File

@ -19,15 +19,15 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
// CHECK: %[[T2:.*]] = torch.copy.to_tensor %[[VT2]] : !torch.tensor
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[LIST_ORIG:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]], %[[T2]] :
// CHECK-SAME: (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
// CHECK-SAME: (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<tensor>
// CHECK: %[[VT0_COPY:.*]] = torch.copy.to_vtensor %[[T0]] : !torch.vtensor
// CHECK: %[[VT1_COPY:.*]] = torch.copy.to_vtensor %[[T1]] : !torch.vtensor
// CHECK: %[[VT2_COPY:.*]] = torch.copy.to_vtensor %[[T2]] : !torch.vtensor
// CHECK: %[[LIST_NEW:.*]] = torch.prim.ListConstruct
// CHECK-SAME: %[[VT0_COPY]], %[[VT1_COPY]], %[[VT2_COPY]] :
// CHECK-SAME: (!torch.vtensor, !torch.vtensor, !torch.vtensor) -> !torch.list<!torch.vtensor>
// CHECK-SAME: (!torch.vtensor, !torch.vtensor, !torch.vtensor) -> !torch.list<vtensor>
// CHECK: %[[VRET:.*]] = torch.aten.cat %[[LIST_NEW]], %[[DIM]] :
// CHECK-SAME: !torch.list<!torch.vtensor>, !torch.int -> !torch.vtensor
// CHECK-SAME: !torch.list<vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor {
@ -35,8 +35,8 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
%t1 = torch.copy.to_tensor %vt1 : !torch.tensor
%t2 = torch.copy.to_tensor %vt2 : !torch.tensor
%int1 = torch.constant.int 1
%list = torch.prim.ListConstruct %t0, %t1, %t2 : (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
%ret = torch.aten.cat %list, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
%list = torch.prim.ListConstruct %t0, %t1, %t2 : (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<tensor>
%ret = torch.aten.cat %list, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
@ -46,23 +46,23 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_TENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_TENSOR]] :
// CHECK-SAME: !torch.tensor<[4],f32> to !torch.optional<!torch.tensor>
// CHECK: %[[BIAS_NONE_OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK-SAME: !torch.tensor<[4],f32> to !torch.optional<tensor>
// CHECK: %[[BIAS_NONE_OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[VINPUT:.*]] = torch.copy.to_vtensor %[[INPUT]] : !torch.vtensor
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[WEIGHTS_TENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<vtensor<[4],f32>>
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[MEAN_VTENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<vtensor<[4],f32>>
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[VAR_VTENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<vtensor<[4],f32>>
// CHECK: %[[VRET:.*]] = torch.aten.batch_norm %[[VINPUT]], %[[WEIGHTS_TENSOR_OPTIONAL]],
// CHECK-SAME: %[[BIAS_NONE_OPTIONAL]], %[[MEAN_VTENSOR_OPTIONAL]], %[[VAR_VTENSOR_OPTIONAL]],
// CHECK-SAME: %[[TRAINING]], %[[FLOAT]], %[[FLOAT]], %[[CUDNN_ENABLE]] :
// CHECK-SAME: !torch.vtensor, !torch.optional<!torch.vtensor<[4],f32>>, !torch.optional<!torch.tensor>,
// CHECK-SAME: !torch.optional<!torch.vtensor<[4],f32>>, !torch.optional<!torch.vtensor<[4],f32>>,
// CHECK-SAME: !torch.vtensor, !torch.optional<vtensor<[4],f32>>, !torch.optional<tensor>,
// CHECK-SAME: !torch.optional<vtensor<[4],f32>>, !torch.optional<vtensor<[4],f32>>,
// CHECK-SAME: !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
@ -73,12 +73,12 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
%cudnn_enable: !torch.bool,
%f : !torch.float) -> !torch.tensor {
%none = torch.constant.none
%tensor_optional = torch.derefine %ft: !torch.tensor<[4],f32> to !torch.optional<!torch.tensor>
%none_optional = torch.derefine %none : !torch.none to !torch.optional<!torch.tensor>
%tensor_optional = torch.derefine %ft: !torch.tensor<[4],f32> to !torch.optional<tensor>
%none_optional = torch.derefine %none : !torch.none to !torch.optional<tensor>
%ret = torch.aten.batch_norm %t, %tensor_optional, %none_optional, %tensor_optional,
%tensor_optional, %training, %f, %f, %cudnn_enable:
!torch.tensor, !torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.tensor, !torch.optional<tensor>, !torch.optional<tensor>,
!torch.optional<tensor>, !torch.optional<tensor>,
!torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}
@ -117,16 +117,16 @@ func @torch.tensor.literal() -> !torch.tensor {
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDICES]] :
// CHECK-SAME: (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
// CHECK-SAME: (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
// CHECK: %[[INDICES_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDICES]] : !torch.vtensor<[2,3],si64>
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<!torch.vtensor<[2,3],si64>>
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<!torch.vtensor<[2,3],si64>> -> !torch.vtensor
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<vtensor<[2,3],si64>>
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
%tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>> -> !torch.tensor
%tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor<[2,3],si64>>> -> !torch.tensor
return %ret : !torch.tensor
}

View File

@ -6,28 +6,28 @@
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<!torch.tensor>) {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional<!torch.tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<!torch.tensor>
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: } else {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional<!torch.tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<!torch.tensor>
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: }
// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional<!torch.tensor> -> !torch.tensor
// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional<tensor> -> !torch.tensor
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
%res = torch.prim.If %pred -> (!torch.optional<!torch.tensor>) {
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<!torch.tensor>
torch.prim.If.yield %optional0: !torch.optional<!torch.tensor>
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor>
} else {
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<!torch.tensor>
torch.prim.If.yield %optional1: !torch.optional<!torch.tensor>
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional1: !torch.optional<tensor>
}
%none = torch.constant.none
%cmp = torch.aten.__isnot__ %res, %none : !torch.optional<!torch.tensor>, !torch.none -> !torch.bool
%cmp = torch.aten.__isnot__ %res, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
return %cmp : !torch.bool
}
@ -35,37 +35,37 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// CHECK-LABEL: func @prim.if$branch_merge_type_optional(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<!torch.tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<!torch.tensor>) {
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<!torch.tensor>
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: } else {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional<!torch.tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<!torch.tensor>
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional<tensor>
// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional<tensor>
// CHECK: }
// CHECK: return %[[MERGED:.*]] : !torch.optional<!torch.tensor>
// CHECK: return %[[MERGED:.*]] : !torch.optional<tensor>
func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<!torch.tensor> {
%res = torch.prim.If %pred -> (!torch.optional<!torch.tensor>) {
func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%none = torch.constant.none
%optional0 = torch.derefine %none: !torch.none to !torch.optional<!torch.tensor>
torch.prim.If.yield %optional0: !torch.optional<!torch.tensor>
%optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor>
} else {
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<!torch.tensor>
torch.prim.If.yield %optional1: !torch.optional<!torch.tensor>
%optional1 = torch.derefine %t1: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional1: !torch.optional<tensor>
}
return %res: !torch.optional<!torch.tensor>
return %res: !torch.optional<tensor>
}
// -----
// CHECK-LABEL: func @prim.if$refined_type_conflicting(
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
// CHECK: %[[PRED:.*]] = torch.prim.If %[[NOT_NONE]] -> (!torch.tensor) {
// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional<!torch.tensor> -> !torch.tensor
// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional<tensor> -> !torch.tensor
// CHECK: torch.prim.If.yield %[[T]] : !torch.tensor
// CHECK: } else {
// CHECK: %[[LITERAL:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.tensor
@ -74,10 +74,10 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// CHECK: return %[[PRED:.*]] : !torch.tensor
func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
%optional = torch.derefine %none: !torch.none to !torch.optional<!torch.tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<!torch.tensor>, !torch.none -> !torch.bool
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
%res = torch.prim.If %pred -> (!torch.tensor) {
%t = torch.prim.unchecked_cast %optional: !torch.optional<!torch.tensor> -> !torch.tensor
%t = torch.prim.unchecked_cast %optional: !torch.optional<tensor> -> !torch.tensor
torch.prim.If.yield %t: !torch.tensor
} else {
%t_cst = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
@ -89,33 +89,33 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// -----
// CHECK-LABEL: func @prim.loop$region_arg_to_internal(
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<!torch.tensor> {
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> {
// CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[INDV:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[LOOP_RET:.*]] = torch.prim.Loop %[[INT10]], %[[TRUE]], init(%[[OPTIONAL]]) {
// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional<!torch.tensor>):
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional<!torch.tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional<tensor>):
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional<tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[COND:.*]] = torch.aten.__isnot__ %[[NONE]], %[[ARG_NONE]] : !torch.none, !torch.none -> !torch.bool
// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional<!torch.tensor>)
// CHECK: } : (!torch.int, !torch.bool, !torch.optional<!torch.tensor>) -> !torch.optional<!torch.tensor>
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional<!torch.tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: return %[[OPTIONAL]] : !torch.optional<!torch.tensor>
// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional<tensor>)
// CHECK: } : (!torch.int, !torch.bool, !torch.optional<tensor>) -> !torch.optional<tensor>
// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional<tensor> -> !torch.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: return %[[OPTIONAL]] : !torch.optional<tensor>
func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<!torch.tensor> {
func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
%int10 = torch.constant.int 10
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%optional = torch.derefine %none: !torch.none to !torch.optional<!torch.tensor>
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%ret = torch.prim.Loop %int10, %true, init(%optional) {
^bb0(%arg2: !torch.int, %arg3: !torch.optional<!torch.tensor>): // no predecessors
%cond = torch.aten.__isnot__ %arg3, %none : !torch.optional<!torch.tensor>, !torch.none -> !torch.bool
torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional<!torch.tensor>)
} : (!torch.int, !torch.bool, !torch.optional<!torch.tensor>) -> (!torch.optional<!torch.tensor>)
return %ret: !torch.optional<!torch.tensor>
^bb0(%arg2: !torch.int, %arg3: !torch.optional<tensor>): // no predecessors
%cond = torch.aten.__isnot__ %arg3, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional<tensor>)
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
return %ret: !torch.optional<tensor>
}
// -----

View File

@ -76,9 +76,9 @@ func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]]
// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]]
// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list<!torch.int>, !torch.bool, !torch.none
// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
@ -87,8 +87,8 @@ func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
%none = torch.constant.none
%int0 = torch.constant.int 0
%int-1 = torch.constant.int -1
%dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor
%dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
%ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
@ -123,15 +123,15 @@ func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%int2 = torch.constant.int 2
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
@ -152,14 +152,14 @@ func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<!torch.tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<!torch.tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
@ -315,34 +315,34 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
// -----
// CHECK-LABEL: func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK-SAME: : !torch.list<!torch.list<!torch.float>>, !torch.none, !torch.none, !torch.bool
// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
// CHECK-SAME: -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<!torch.list<!torch.float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<!torch.list<!torch.float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<!torch.list<!torch.float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
%ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}

View File

@ -9,9 +9,9 @@
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<!torch.int>) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<int>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
@ -31,9 +31,9 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[VALUE:.*]] = torch.pseudo.aten.fill.Scalar %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<!torch.int>, !torch.float) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
@ -55,10 +55,10 @@ func @pseudo_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK: %[[UNIFORM:.*]] = torch.pseudo.aten.uniform %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[NONE]] : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[UNIFORM]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ANY:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.any
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<!torch.int>, !torch.float, !torch.float, !torch.any) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<int>, !torch.float, !torch.float, !torch.any) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
@ -84,11 +84,11 @@ func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.f
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[INT1]] : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[ADD]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ARG1_SHAPE:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ARG1_SHAPE:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SCALAR_CONVERTED:.*]] = torch.aten.Float.Scalar %[[INT1]] : !torch.int -> !torch.float
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<!torch.int>, !torch.list<!torch.int>, !torch.float) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
@ -111,10 +111,10 @@ func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vten
// CHECK: %[[TOP_VALUES:.*]], %[[TOPK_INDICES:.*]] = torch.aten.topk %[[ARG]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]] : !torch.tensor, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.tensor, !torch.tensor
// CHECK: torch.shape.calculate.yield %[[TOP_VALUES]], %[[TOPK_INDICES]] : !torch.tensor, !torch.tensor
// CHECK: } shapes {
// CHECK: %[[ARG_SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.tensor -> !torch.list<!torch.int>
// CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<!torch.int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<!torch.list<!torch.int>, !torch.list<!torch.int>>
// CHECK: %[[TOPK_SHAPE:.*]]:2 = torch.prim.TupleUnpack %[[TOPK_SHAPE_TUPLE]] : !torch.tuple<!torch.list<!torch.int>, !torch.list<!torch.int>> -> !torch.list<!torch.int>, !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[TOPK_SHAPE]]#0, %[[TOPK_SHAPE]]#1 : !torch.list<!torch.int>, !torch.list<!torch.int>
// CHECK: %[[ARG_SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.tensor -> !torch.list<int>
// CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<list<int>, list<int>>
// CHECK: %[[TOPK_SHAPE:.*]]:2 = torch.prim.TupleUnpack %[[TOPK_SHAPE_TUPLE]] : !torch.tuple<list<int>, list<int>> -> !torch.list<int>, !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[TOPK_SHAPE]]#0, %[[TOPK_SHAPE]]#1 : !torch.list<int>, !torch.list<int>
// CHECK: } : !torch.tensor, !torch.tensor
// CHECK: return %[[RESULTS:.*]]#0, %[[RESULTS]]#1 : !torch.tensor, !torch.tensor
@ -132,14 +132,14 @@ func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[CONV:.*]] = torch.aten.conv2d %[[ARG0]], %[[ARG1]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int -> !torch.vtensor
// CHECK: %[[CONV:.*]] = torch.aten.conv2d %[[ARG0]], %[[ARG1]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[CONV]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[SHAPE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[DEREFINED:.*]] = torch.derefine %{{.*}} : !torch.none to !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<!torch.int>, !torch.list<!torch.int>, !torch.optional<!torch.list<!torch.int>>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SHAPE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[DEREFINED:.*]] = torch.derefine %{{.*}} : !torch.none to !torch.optional<list<int>>
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
@ -148,10 +148,10 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%none = torch.constant.none
%24 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%25 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%26 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%29 = torch.aten.conv2d %arg0, %arg1, %none, %24, %25, %26, %int1 : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int -> !torch.vtensor
%24 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%25 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%26 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%29 = torch.aten.conv2d %arg0, %arg1, %none, %24, %25, %26, %int1 : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor
return %29 : !torch.vtensor
}
@ -164,28 +164,28 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
// CHECK: %[[C1EM5:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[C1EM1:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.vtensor to !torch.optional<!torch.vtensor>
// CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.vtensor to !torch.optional<vtensor>
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[BN:.*]] = torch.aten.batch_norm %[[ARG]], %[[DEREFINED]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]] : !torch.vtensor, !torch.optional<!torch.vtensor>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
// CHECK: %[[BN:.*]] = torch.aten.batch_norm %[[ARG]], %[[DEREFINED]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]] : !torch.vtensor, !torch.optional<vtensor>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[BN]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[ARG_SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ARG_SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[NONE2:.*]] = torch.constant.none
// CHECK: %[[IS:.*]] = torch.aten.__is__ %[[DEREFINED]], %[[NONE2]] : !torch.optional<!torch.vtensor>, !torch.none -> !torch.bool
// CHECK: %[[DEREFINED_OPTIONAL_SIZE:.*]] = torch.prim.If %[[IS]] -> (!torch.optional<!torch.list<!torch.int>>) {
// CHECK: %[[DEREFINE_NONE:.*]] = torch.derefine %[[NONE2]] : !torch.none to !torch.optional<!torch.list<!torch.int>>
// CHECK: torch.prim.If.yield %[[DEREFINE_NONE]] : !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[IS:.*]] = torch.aten.__is__ %[[DEREFINED]], %[[NONE2]] : !torch.optional<vtensor>, !torch.none -> !torch.bool
// CHECK: %[[DEREFINED_OPTIONAL_SIZE:.*]] = torch.prim.If %[[IS]] -> (!torch.optional<list<int>>) {
// CHECK: %[[DEREFINE_NONE:.*]] = torch.derefine %[[NONE2]] : !torch.none to !torch.optional<list<int>>
// CHECK: torch.prim.If.yield %[[DEREFINE_NONE]] : !torch.optional<list<int>>
// CHECK: } else {
// CHECK: %[[DOWNCASTED:.*]] = torch.prim.unchecked_cast %[[DEREFINED]] : !torch.optional<!torch.vtensor> -> !torch.vtensor
// CHECK: %[[DOWNCASTED_SIZE:.*]] = torch.aten.size %[[DOWNCASTED]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[DEREFINE_DOWNCASTED_SIZE:.*]] = torch.derefine %[[DOWNCASTED_SIZE]] : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
// CHECK: torch.prim.If.yield %[[DEREFINE_DOWNCASTED_SIZE]] : !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[DOWNCASTED:.*]] = torch.prim.unchecked_cast %[[DEREFINED]] : !torch.optional<vtensor> -> !torch.vtensor
// CHECK: %[[DOWNCASTED_SIZE:.*]] = torch.aten.size %[[DOWNCASTED]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[DEREFINE_DOWNCASTED_SIZE:.*]] = torch.derefine %[[DOWNCASTED_SIZE]] : !torch.list<int> to !torch.optional<list<int>>
// CHECK: torch.prim.If.yield %[[DEREFINE_DOWNCASTED_SIZE]] : !torch.optional<list<int>>
// CHECK: }
// CHECK: %[[DEREFINED_NONE1:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[DEREFINED_NONE2:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[DEREFINED_NONE3:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.list<!torch.int>>
// CHECK: %[[BN_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<!torch.int>, !torch.optional<!torch.list<!torch.int>>, !torch.optional<!torch.list<!torch.int>>, !torch.optional<!torch.list<!torch.int>>, !torch.optional<!torch.list<!torch.int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[BN_SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[DEREFINED_NONE1:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE2:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE3:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[BN_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[BN_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
@ -194,8 +194,8 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
%float1.000000e-05 = torch.constant.float 1.000000e-05
%float1.000000e-01 = torch.constant.float 1.000000e-01
%none = torch.constant.none
%derefined_tensor = torch.derefine %arg0 : !torch.vtensor to !torch.optional<!torch.vtensor>
%0 = torch.aten.batch_norm %arg0, %derefined_tensor, %none, %none, %none, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor, !torch.optional<!torch.vtensor>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
%derefined_tensor = torch.derefine %arg0 : !torch.vtensor to !torch.optional<vtensor>
%0 = torch.aten.batch_norm %arg0, %derefined_tensor, %none, %none, %none, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor, !torch.optional<vtensor>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
return %0 : !torch.vtensor
}
@ -204,29 +204,29 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
// CHECK-LABEL: func @adjust_shape_function_arg$list(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor) -> !torch.list<!torch.vtensor>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor) -> !torch.list<vtensor>
// CHECK: %[[VAL_3:.*]] = torch.shape.calculate {
// CHECK: %[[VAL_4:.*]] = torch.aten.index.Tensor %[[ARG0]], %[[LIST]] : !torch.vtensor, !torch.list<!torch.vtensor> -> !torch.vtensor
// CHECK: %[[VAL_4:.*]] = torch.aten.index.Tensor %[[ARG0]], %[[LIST]] : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[VAL_4]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ADJUSTED_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.optional<!torch.list<!torch.int>>>
// CHECK: %[[LIST_SIZE:.*]] = torch.aten.len.t %[[LIST]] : !torch.list<!torch.vtensor> -> !torch.int
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ADJUSTED_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<optional<list<int>>>
// CHECK: %[[LIST_SIZE:.*]] = torch.aten.len.t %[[LIST]] : !torch.list<vtensor> -> !torch.int
// CHECK: %[[CTRUE:.*]] = torch.constant.bool true
// CHECK: torch.prim.Loop %[[LIST_SIZE]], %[[CTRUE]], init() {
// CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int):
// CHECK: %[[UNADJUSTED_ELEMENT:.*]] = torch.aten.__getitem__.t %[[LIST]], %[[ITER_NUM]] : !torch.list<!torch.vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[UNADJUSTED_ELEMENT_SHAPE:.*]] = torch.aten.size %[[UNADJUSTED_ELEMENT]] : !torch.vtensor -> !torch.list<!torch.int>
// CHECK: %[[ADJUSTED_ELEMENT:.*]] = torch.derefine %[[UNADJUSTED_ELEMENT_SHAPE]] : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
// CHECK: %{{.*}} = torch.aten.append.t %[[ADJUSTED_LIST]], %[[ADJUSTED_ELEMENT]] : !torch.list<!torch.optional<!torch.list<!torch.int>>>, !torch.optional<!torch.list<!torch.int>> -> !torch.list<!torch.optional<!torch.list<!torch.int>>>
// CHECK: %[[UNADJUSTED_ELEMENT:.*]] = torch.aten.__getitem__.t %[[LIST]], %[[ITER_NUM]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[UNADJUSTED_ELEMENT_SHAPE:.*]] = torch.aten.size %[[UNADJUSTED_ELEMENT]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ADJUSTED_ELEMENT:.*]] = torch.derefine %[[UNADJUSTED_ELEMENT_SHAPE]] : !torch.list<int> to !torch.optional<list<int>>
// CHECK: %{{.*}} = torch.aten.append.t %[[ADJUSTED_LIST]], %[[ADJUSTED_ELEMENT]] : !torch.list<optional<list<int>>>, !torch.optional<list<int>> -> !torch.list<optional<list<int>>>
// CHECK: torch.prim.Loop.condition %[[CTRUE]], iter()
// CHECK: } : (!torch.int, !torch.bool) -> ()
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<!torch.int>, !torch.list<!torch.optional<!torch.list<!torch.int>>>) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[VAL_15:.*]] : !torch.vtensor
func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor) -> !torch.list<!torch.vtensor>
%1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<!torch.vtensor> -> !torch.vtensor
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
return %1 : !torch.vtensor
}

View File

@ -9,8 +9,8 @@
// CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor to !torch.vtensor<[2,?],unk>
// CHECK: torch.shape.calculate.yield %[[REFINED]] : !torch.vtensor<[2,?],unk>
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[ARG1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[2,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
@ -19,8 +19,8 @@ func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.i
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -34,15 +34,15 @@ func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If %arg2 -> () {
// Clobber element 0 of the list. So we can only know that the result is [?,2] instead of [2,2].
%2 = torch.aten._set_item.t %1, %int0, %arg1 : !torch.list<!torch.int>, !torch.int, !torch.int -> !torch.list<!torch.int>
%2 = torch.aten._set_item.t %1, %int0, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -56,16 +56,16 @@ func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor,
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.prim.If %arg2 -> () {
// Set an unknown element of the list. This clobbers our knowledge of the whole contents of the list.
// So we can only know that the result is [?,?] instead of [2,2].
%2 = torch.aten._set_item.t %1, %arg1, %int0 : !torch.list<!torch.int>, !torch.int, !torch.int -> !torch.list<!torch.int>
%2 = torch.aten._set_item.t %1, %arg1, %int0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
torch.prim.If.yield
}
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -80,8 +80,8 @@ func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor<[?,3],f32>
} shapes {
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %int2, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor<[?,3],f32>
return %0 : !torch.vtensor<[?,3],f32>
}
@ -95,8 +95,8 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
%1 = torch.shape.calculate {
torch.shape.calculate.yield %0 : !torch.vtensor
} shapes {
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %2 : !torch.list<!torch.int>
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
torch.shape.calculate.yield.shapes %2 : !torch.list<int>
} : !torch.vtensor
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
return %2 : !torch.vtensor
@ -104,7 +104,7 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
// CHECK-LABEL: func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -114,10 +114,10 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
// CHECK: torch.prim.Print(%[[INT0]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.prim.Print(%[[INT1]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.prim.Print(%[[INT2]], %[[INT0]]) : !torch.int, !torch.int
// CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<!torch.int>) -> !torch.vtensor {
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
@ -129,14 +129,14 @@ func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<!t
torch.prim.Print(%arg2, %arg3) : !torch.int, !torch.int
torch.prim.Loop.condition %true, iter(%arg3: !torch.int)
} : (!torch.int, !torch.bool, !torch.int) -> (!torch.int)
torch.shape.calculate.yield.shapes %arg1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @fully_unroll_prim_loop$no_unroll(
// CHECK: torch.prim.Loop
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<!torch.int>, %arg2: !torch.int) -> !torch.vtensor {
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%0 = torch.shape.calculate {
@ -147,7 +147,7 @@ func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list
torch.prim.Print(%arg2) : !torch.int
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %arg1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -156,24 +156,24 @@ func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
// CHECK-SAME: %[[ARG2:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// Test the different supported mutation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@ -182,11 +182,11 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
%2 = torch.aten._set_item.t %1, %int1, %arg1 : !torch.list<!torch.int>, !torch.int, !torch.int -> !torch.list<!torch.int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.aten.insert.t %1, %int3, %arg3 : !torch.list<!torch.int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten._set_item.t %1, %int1, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.aten.insert.t %1, %int3, %arg3 : !torch.list<int>, !torch.int, !torch.int
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -199,12 +199,12 @@ func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.v
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
// The value of the alias %2 is printed, but we don't handle that yet.
torch.prim.Print(%2) : !torch.list<!torch.int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
torch.prim.Print(%2) : !torch.list<int>
%3 = torch.aten.append.t %1, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -213,8 +213,8 @@ func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.v
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
@ -222,17 +222,17 @@ func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vt
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
// This readonly op in a loop doesn't block us from abstractly interpreting
// the whole block.
torch.prim.Loop %arg1, %true, init() {
^bb0(%arg3: !torch.int):
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.prim.Print(%2) : !torch.list<!torch.int>
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Print(%2) : !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%2 = torch.aten.append.t %1, %int3 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -247,24 +247,24 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %arg1, %true, init() {
^bb0(%arg3: !torch.int):
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.prim.Print(%2) : !torch.list<!torch.int>
%2 = torch.aten.__getitem__.t %1, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Print(%2) : !torch.list<int>
// This mutation prevents us from abstractly interpreting.
%3 = torch.aten.append.t %1, %arg1 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%3 = torch.aten.append.t %1, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
%2 = torch.aten.append.t %1, %int3 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<!torch.int>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
@ -272,34 +272,34 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[3,3],unk>
// CHECK: } shapes {
// Notice this torch.prim.ListConstruct....
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_7:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<!torch.int>) {
// CHECK: torch.prim.If.yield %[[VAL_6]] : !torch.list<!torch.int>
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK: torch.prim.If.yield %[[VAL_6]] : !torch.list<int>
// CHECK: } else {
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<!torch.int>
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK: }
// .... and this one don't have the same object identity, but should!
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<!torch.int>) {
// CHECK: torch.prim.If.yield %[[VAL_8]] : !torch.list<!torch.int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list<int>) {
// CHECK: torch.prim.If.yield %[[VAL_8]] : !torch.list<int>
// CHECK: } else {
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<!torch.int>
// CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list<int>
// CHECK: }
// CHECK: %[[VAL_10:.*]] = torch.aten.__is__ %[[VAL_11:.*]], %[[VAL_12:.*]] : !torch.list<!torch.int>, !torch.list<!torch.int> -> !torch.bool
// CHECK: %[[VAL_10:.*]] = torch.aten.__is__ %[[VAL_11:.*]], %[[VAL_12:.*]] : !torch.list<int>, !torch.list<int> -> !torch.bool
// CHECK: torch.prim.Print(%[[VAL_10]]) : !torch.bool
// CHECK: torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[3,3],unk>
// CHECK: %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor
// CHECK: return %[[VAL_13]] : !torch.vtensor
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<!torch.int>, %arg2: !torch.bool) -> !torch.vtensor {
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%2 = torch.aten.append.t %1, %int3 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
// TODO: Fix this miscompile!
// For the case where %arg2 is true, the resulting IR will miscompile
// because the abstract interpretation of the list ops will create two list
@ -311,20 +311,20 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// replace a single list literal at a time, and bail out if there are any
// uses of the original list value that are not replaced by the created
// literal.
%3 = torch.prim.If %arg2 -> (!torch.list<!torch.int>) {
torch.prim.If.yield %1 : !torch.list<!torch.int>
%3 = torch.prim.If %arg2 -> (!torch.list<int>) {
torch.prim.If.yield %1 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<!torch.int>
torch.prim.If.yield %arg1 : !torch.list<int>
}
%4 = torch.aten.append.t %1, %int3 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%5 = torch.prim.If %arg2 -> (!torch.list<!torch.int>) {
torch.prim.If.yield %1 : !torch.list<!torch.int>
%4 = torch.aten.append.t %1, %int3 : !torch.list<int>, !torch.int -> !torch.list<int>
%5 = torch.prim.If %arg2 -> (!torch.list<int>) {
torch.prim.If.yield %1 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<!torch.int>
torch.prim.If.yield %arg1 : !torch.list<int>
}
%6 = torch.aten.__is__ %3, %5 : !torch.list<!torch.int>, !torch.list<!torch.int> -> !torch.bool
%6 = torch.aten.__is__ %3, %5 : !torch.list<int>, !torch.list<int> -> !torch.bool
torch.prim.Print(%6) : !torch.bool
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}
@ -345,8 +345,8 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// CHECK: } shapes {
// CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK: %[[SIZE1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE0]], %[[SIZE1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<!torch.int>
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE0]], %[[SIZE1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor<[?,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
@ -356,15 +356,15 @@ func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor
torch.shape.calculate.yield %1 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%2 = torch.aten.dim %arg0 : !torch.vtensor<[?,?],unk> -> !torch.int
torch.prim.Loop %2, %true, init() {
^bb0(%arg1: !torch.int):
%3 = torch.aten.size.int %arg0, %arg1 : !torch.vtensor<[?,?],unk>, !torch.int -> !torch.int
%4 = torch.aten.append.t %1, %3 : !torch.list<!torch.int>, !torch.int -> !torch.list<!torch.int>
%4 = torch.aten.append.t %1, %3 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.shape.calculate.yield.shapes %1 : !torch.list<!torch.int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}

View File

@ -20,16 +20,16 @@ class TestModule(torch.nn.Module):
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# CHECK: torch.attr "training" : !torch.bool
# CHECK: torch.attr "_is_full_backward_hook" : !torch.optional<!torch.bool>
# CHECK: torch.attr "d" : !torch.dict<!torch.str, !torch.tensor>
# CHECK: torch.attr "_is_full_backward_hook" : !torch.optional<bool>
# CHECK: torch.attr "d" : !torch.dict<str, tensor>
# CHECK: }
# CHECK: %[[K:.*]] = torch.constant.str "key1"
# CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1> : tensor<si64>) : !torch.tensor<[],si64>
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
# CHECK-SAME keys(%[[K]] : !torch.str) values(%[[TENSOR]] : !torch.tensor<[],si64>)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.tensor>
# CHECK-SAME: -> !torch.dict<str, tensor>
# CHECK: torch.nn_module {
# CHECK: torch.slot "d", %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
# CHECK: torch.slot "d", %[[DICT]] : !torch.dict<str, tensor>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()

View File

@ -16,13 +16,13 @@ class TestModule(torch.nn.Module):
super().__init__()
self.l = [1, 2]
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# CHECK: torch.attr "l" : !torch.list<!torch.int>
# CHECK: torch.attr "l" : !torch.list<int>
# CHECK: }
# CHECK: %[[N1:.*]] = torch.constant.int 1
# CHECK: %[[N2:.*]] = torch.constant.int 2
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (!torch.int, !torch.int) -> !torch.list<int>
# CHECK: torch.nn_module {
# CHECK: torch.slot "l", %[[LIST]] : !torch.list<!torch.int>
# CHECK: torch.slot "l", %[[LIST]] : !torch.list<int>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">

View File

@ -18,11 +18,11 @@ class TestModule(torch.nn.Module):
super().__init__()
# CHECK-LABEL: func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<!torch.int> {
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.int>
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional<!torch.int>) -> !torch.optional<!torch.int>
# CHECK: return %[[RET]] : !torch.optional<!torch.int>
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[SELF]]["callee"] (%[[DEREFINED]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.optional<int>) -> !torch.optional<int>
# CHECK: return %[[RET]] : !torch.optional<int>
def forward(self):
return self.callee(None)
def callee(self, o: typing.Optional[int]):

View File

@ -22,7 +22,7 @@ class TestModule(torch.nn.Module):
# CHECK: %[[N2:.*]] = torch.constant.int 2
# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[N1]], %[[N2]] : !torch.int, !torch.int
# CHECK: torch.nn_module {
# CHECK: torch.slot "t", %[[TUPLE]] : !torch.tuple<!torch.int, !torch.int>
# CHECK: torch.slot "t", %[[TUPLE]] : !torch.tuple<int, int>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">

View File

@ -12,9 +12,9 @@ from typing import Tuple, Optional, List, NamedTuple, Dict
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<!torch.str, !torch.tensor>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<str, tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<str, tensor>
# CHECK: return %[[DICT]] : !torch.dict<str, tensor>
@mb.import_function
@torch.jit.script
def dict_literal_empty() -> Dict[str, torch.Tensor]:
@ -24,12 +24,12 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]:
# CHECK-LABEL: func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.optional<!torch.tensor>> {
# CHECK-SAME: -> !torch.dict<str, optional<tensor>> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
# CHECK-SAME: keys(%[[K0]], %[[K1]] : !torch.str, !torch.str)
# CHECK-SAME: values(%[[V0]], %[[V1]] : !torch.tensor, !torch.tensor) ->
# CHECK-SAME: !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.optional<!torch.tensor>>
# CHECK-SAME: !torch.dict<str, optional<tensor>>
# CHECK: return %[[DICT]] : !torch.dict<str, optional<tensor>>
# CHECK: }
@mb.import_function
@torch.jit.script

View File

@ -12,16 +12,16 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<!torch.int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<!torch.int>
# CHECK: return %[[RET]] : !torch.optional<!torch.int>
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
# CHECK: return %[[RET]] : !torch.optional<int>
@mb.import_function
@torch.jit.script
def optional_return(i: int) -> typing.Optional[int]:
return i
# CHECK-LABEL: func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<!torch.int>) -> !torch.none {
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.none {
@mb.import_function
@torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None:
@ -29,9 +29,9 @@ def optional_arg(i: typing.Optional[int]) -> None:
# CHECK-LABEL: func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.none {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<!torch.int>) -> !torch.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<!torch.int>
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<!torch.int>) -> !torch.none
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<int>) -> !torch.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<int>) -> !torch.none
@mb.import_function
@torch.jit.script
def calls_optional_arg(i: int):

View File

@ -34,16 +34,16 @@ def prim_If(b: bool, i: int):
# CHECK-LABEL: func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional<!torch.int> {
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (!torch.optional<!torch.int>) {
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.int>
# CHECK: torch.prim.If.yield %[[NONE_DEREFINED]] : !torch.optional<!torch.int>
# CHECK: %[[RES:.*]] = torch.prim.If %[[B]] -> (!torch.optional<int>) {
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>
# CHECK: torch.prim.If.yield %[[NONE_DEREFINED]] : !torch.optional<int>
# CHECK: } else {
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : !torch.int to !torch.optional<!torch.int>
# CHECK: torch.prim.If.yield %[[I_DEREFINED]] : !torch.optional<!torch.int>
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : !torch.int to !torch.optional<int>
# CHECK: torch.prim.If.yield %[[I_DEREFINED]] : !torch.optional<int>
# CHECK: }
# CHECK: return %[[RES:.*]] : !torch.optional<!torch.int>
# CHECK: return %[[RES:.*]] : !torch.optional<int>
@mb.import_function
@torch.jit.script
def prim_If_derefine(b: bool, i: int):

View File

@ -11,9 +11,9 @@ mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<!torch.tensor> {
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
# CHECK: return %[[RET]] : !torch.list<!torch.tensor>
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<tensor> {
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<tensor>
# CHECK: return %[[RET]] : !torch.list<tensor>
@mb.import_function
@torch.jit.script

View File

@ -50,16 +50,16 @@ def prim_Loop_whilelike(n: int):
return f
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<!torch.int> {
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[TRUE:.*]] = torch.constant.bool true
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.int>
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
# CHECK: ^bb0(%[[IV:.*]]: !torch.int, %[[X_ITER:.*]]: !torch.optional<!torch.int>):
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<!torch.int>
# CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[X_NEXT]] : !torch.optional<!torch.int>)
# CHECK: } : (!torch.int, !torch.bool, !torch.optional<!torch.int>) -> !torch.optional<!torch.int>
# CHECK: return %[[RET:.*]] : !torch.optional<!torch.int>
# CHECK: ^bb0(%[[IV:.*]]: !torch.int, %[[X_ITER:.*]]: !torch.optional<int>):
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
# CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[X_NEXT]] : !torch.optional<int>)
# CHECK: } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
# CHECK: return %[[RET:.*]] : !torch.optional<int>
@mb.import_function
@torch.jit.script
def prim_Loop_derefine(n: int):

View File

@ -50,14 +50,14 @@ def prim_RaiseException():
raise Exception("Error")
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<!torch.int>) -> !torch.int {
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.int {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[C3:.*]] = torch.constant.int 3
# CHECK: %[[IS_NONE:.*]] = torch.aten.__is__ %[[ARG]], %[[NONE]] : !torch.optional<!torch.int>, !torch.none -> !torch.bool
# CHECK: %[[IS_NONE:.*]] = torch.aten.__is__ %[[ARG]], %[[NONE]] : !torch.optional<int>, !torch.none -> !torch.bool
# CHECK: %[[RESULT:.*]] = torch.prim.If %[[IS_NONE]] -> (!torch.int) {
# CHECK: torch.prim.If.yield %[[C3]] : !torch.int
# CHECK: } else {
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<!torch.int> -> !torch.int
# CHECK: %[[CASTED:.*]] = torch.prim.unchecked_cast %[[ARG]] : !torch.optional<int> -> !torch.int
# CHECK: torch.prim.If.yield %[[CASTED]] : !torch.int
# CHECK: }
# CHECK: return %[[RESULT:.*]] : !torch.int
@ -69,8 +69,8 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return i
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<!torch.int, !torch.int>) -> !torch.int {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<!torch.int, !torch.int> -> !torch.int, !torch.int
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<int, int>) -> !torch.int {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<int, int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#0 : !torch.int
@mb.import_function
@torch.jit.script
@ -79,8 +79,8 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
return val
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<!torch.tensor, !torch.tensor>) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<tensor, tensor>) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor
@mb.import_function
@torch.jit.script
@ -88,8 +88,8 @@ def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
return tup[0]
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.list<!torch.int>) -> !torch.int {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<!torch.int> -> !torch.int, !torch.int
# CHECK-SAME: %[[ARG:.*]]: !torch.list<int>) -> !torch.int {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#1 : !torch.int
@mb.import_function
@torch.jit.script
@ -125,40 +125,40 @@ def prim_device(x):
return x.device
# CHECK-LABEL: func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<!torch.int, !torch.int, !torch.int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<!torch.int>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<!torch.int> -> !torch.int
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
# CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : !torch.int, !torch.int -> !torch.int
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<!torch.int> -> !torch.int
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<int> -> !torch.int
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MIN1]], %[[MIN2]], %[[MIN3]] : !torch.int, !torch.int, !torch.int
# CHECK: return %[[RET]] : !torch.tuple<!torch.int, !torch.int, !torch.int>
# CHECK: return %[[RET]] : !torch.tuple<int, int, int>
@mb.import_function
@torch.jit.script
def prim_min(x: int):
return min(x), min(x,x), min(x, x, x)
# CHECK-LABEL: func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<!torch.int, !torch.int, !torch.int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<!torch.int>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<!torch.int> -> !torch.int
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
# CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : !torch.int, !torch.int -> !torch.int
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<!torch.int> -> !torch.int
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<int> -> !torch.int
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MAX1]], %[[MAX2]], %[[MAX3]] : !torch.int, !torch.int, !torch.int
# CHECK: return %[[RET]] : !torch.tuple<!torch.int, !torch.int, !torch.int>
# CHECK: return %[[RET]] : !torch.tuple<int, int, int>
@mb.import_function
@torch.jit.script
def prim_max(x: int):
return max(x), max(x,x), max(x, x, x)
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<!torch.int> {
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<int> {
# CHECK: %[[A:.*]] = torch.constant.int 1
# CHECK: %[[B:.*]] = torch.constant.int 2
# CHECK: %[[C:.*]] = torch.constant.int 3
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: return %[[RET]] : !torch.list<!torch.int>
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
# CHECK: return %[[RET]] : !torch.list<int>
import_ts_ir('__torch__.prim_Constant_list', '''graph():
%list : int[] = prim::Constant[value=[1, 2, 3]]()
return (%list)''')

View File

@ -16,10 +16,10 @@ NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
# CHECK-LABEL: func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> {
# CHECK-SAME: !torch.tuple<tensor, tensor> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
# CHECK: return %[[RET]] : !torch.tuple<tensor, tensor>
@mb.import_function
@ -31,13 +31,13 @@ def tuple(t0, t1):
# CHECK-LABEL: func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
# CHECK: %[[TNEW:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
# CHECK: %[[RET:.*]] = torch.derefine %[[TNEW]] :
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> to
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK-SAME: !torch.tuple<tensor, tensor> to
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
@mb.import_function
@ -50,11 +50,11 @@ def tuple_optional(
# CHECK-LABEL: func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK: return %[[RET]] : !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>>
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: }
#
@mb.import_function