Update llvm-project to 830c0b9023cd0cf91955900e0d96283e7a8c3711

- builder.getSymbolRefAttr is gone.
- OpAsmOpInterface's getAsmResultNames method needs explicit override
- a bunch of churn for builtin.func needing to be made explicit (and
  sometimes implicit?)
- operation printers no longer need to print the operation name
  themselves.
- snuck in beneficial trivial addition to TmpDeleteDeadIREEListsPass to
  test a particular upstream change e2e with my local patchset.
pull/298/head
Sean Silva 2021-09-03 18:38:00 +00:00
parent 9cc4fdcaa8
commit 1dec561cfd
21 changed files with 184 additions and 175 deletions

View File

@ -43,7 +43,7 @@ class IREE_AliasedSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let constBuilderCall = "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
}
class IREE_AnyPtrOf<list<Type> types> :

@ -1 +1 @@
Subproject commit a8de667af092c9b4b3b4a95827a521602ebf14ed
Subproject commit 830c0b9023cd0cf91955900e0d96283e7a8c3711

View File

@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, NamedTuple, Dict
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: builtin.func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<!torch.str, !torch.tensor>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
@mb.import_function
@ -21,7 +21,7 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]:
return {}
# CHECK-LABEL: builtin.func @__torch__.dict_literal(
# CHECK-LABEL: func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.optional<!torch.tensor>> {

View File

@ -13,7 +13,7 @@ mb = torch_mlir.ModuleBuilder()
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
('f2', Optional[torch.Tensor])])
# CHECK-LABEL: builtin.func @__torch__.tuple(
# CHECK-LABEL: func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> {
@ -28,7 +28,7 @@ def tuple(t0, t1):
return t0, t1
# CHECK-LABEL: builtin.func @__torch__.tuple_optional(
# CHECK-LABEL: func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
@ -47,7 +47,7 @@ def tuple_optional(
return t0, t1
# CHECK-LABEL: builtin.func @__torch__.namedtuple_optional(
# CHECK-LABEL: func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {

View File

@ -95,7 +95,8 @@ def CompareOperationAttr : StrEnumAttr<
//===----------------------------------------------------------------------===//
def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "A constant from the Python3 numeric type hierarchy";
let description = [{
Basicpy re-uses core MLIR types to represent the Python3 numeric type
@ -141,7 +142,8 @@ def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
}
def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "A boolean constant";
let description = [{
A constant of type !basicpy.BoolType that can take either an i1 value
@ -220,7 +222,8 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
}
def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Constant bytes value";
let description = [{
A bytes value of BytesType. The value is represented by a StringAttr.
@ -251,7 +254,8 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
}
def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Constant string value";
let description = [{
A string value of StrType. The value is represented by a StringAttr

View File

@ -534,7 +534,8 @@ def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
//===----------------------------------------------------------------------===//
def Torch_ConstantNoneOp : Torch_Op<"constant.none",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Get the singleton None value.";
let description = [{
Not to be confused with the `mlir::NoneType`. Be careful to use
@ -547,7 +548,8 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none",
}
def Torch_ConstantStrOp : Torch_Op<"constant.str",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant str value.";
let description = [{
Note: Strings in Python (and TorchScript) are immutable.
@ -563,7 +565,8 @@ def Torch_ConstantStrOp : Torch_Op<"constant.str",
}
def Torch_ConstantIntOp : Torch_Op<"constant.int",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `int` value.";
let description = [{
Note: TorchScript represents integers as 64-bit signed values, unlike
@ -581,7 +584,8 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
}
def Torch_ConstantFloatOp : Torch_Op<"constant.float",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `float` value.";
let description = [{
Note: TorchScript represents `float` as 64-bit floating point values.
@ -599,7 +603,8 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float",
}
def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `bool` value.";
let description = [{
}];

View File

@ -62,7 +62,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
let parser = [{
if (parser.parseLess())
return Type();
StringRef className;
std::string className;
if ($_parser.parseOptionalString(&className))
return Type();
if ($_parser.parseGreater())

View File

@ -80,7 +80,7 @@ static ParseResult parseNumericConstantOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, NumericConstantOp op) {
p << "basicpy.numeric_constant ";
p << " ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
if (op->getAttrs().size() > 1)
@ -176,7 +176,6 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
}
static void print(OpAsmPrinter &p, ExecOp op) {
p << op.getOperationName();
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printRegion(op.body());
}
@ -230,7 +229,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, FuncTemplateOp op) {
p << op.getOperationName() << " ";
p << " ";
p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
{SymbolTable::getSymbolAttrName()});
@ -294,7 +293,7 @@ static void print(OpAsmPrinter &p, SlotObjectMakeOp op) {
return;
}
p << op.getOperationName() << "(";
p << "(";
p.printOperands(op.slots());
p << ")";
p.printOptionalAttrDict(op->getAttrs(), {"className"});
@ -358,7 +357,7 @@ static void print(OpAsmPrinter &p, SlotObjectGetOp op) {
return;
}
p << op.getOperationName() << " ";
p << " ";
p.printOperand(op.object());
p << "[" << op.index() << "]";
p.printOptionalAttrDict(op->getAttrs(), {"index"});

View File

@ -21,7 +21,6 @@ using namespace mlir::NPCOMP::refbackrt;
//===----------------------------------------------------------------------===//
static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
p << "refbackrt.module_metadata";
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);

View File

@ -54,7 +54,8 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
//===----------------------------------------------------------------------===//
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
auto func =
symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr());
if (!func)
return emitError() << "'@" << function()
<< "' does not reference a valid function";
@ -132,8 +133,8 @@ bool isValidSubtype(Type subtype, Type type) {
}
LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto classType =
symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName());
auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
*this, SymbolRefAttr::get(getContext(), getClassName()));
if (!classType)
return emitError() << "'" << getClassName()
<< "' does not reference a valid class type";
@ -297,7 +298,7 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
}
static void print(OpAsmPrinter &p, PrimIfOp op) {
p << PrimIfOp::getOperationName() << " " << op.condition();
p << " " << op.condition();
p << " -> (" << op.getResultTypes() << ")";
p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
p << " else";
@ -748,7 +749,7 @@ static ParseResult parseConstantIntOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, Torch::ConstantIntOp op) {
p << Torch::ConstantIntOp::getOperationName() << " ";
p << " ";
p << op.value().getSExtValue();
p.printOptionalAttrDict(op->getAttrs(), {"value"});
}

View File

@ -33,7 +33,7 @@ class TmpDeleteDeadIREEListsPass
SmallVector<Operation *> deadOps;
deadOps.push_back(op);
for (auto &use : op.getResult().getUses()) {
if (isa<iree::ListSetOp>(use.getOwner())) {
if (isa<iree::ListSetOp, iree::ListResizeOp>(use.getOwner())) {
deadOps.push_back(use.getOwner());
} else {
// We can't analyze the list op if it is used by something else.

View File

@ -735,7 +735,8 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
auto wrapper = createWrapperFunc(originalFunc);
op.getResult().setType(LLVMPointerType::get(wrapper.getType()));
Builder builder(op.getContext());
op->setAttr("global_name", builder.getSymbolRefAttr(wrapper.getName()));
op->setAttr("global_name",
SymbolRefAttr::get(builder.getContext(), wrapper.getName()));
});
}
};

View File

@ -231,9 +231,9 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
// Add attributes that are valid for every func (funcName, numInputs,
// numOutputs)
namedAttrs.push_back(
std::make_pair(Identifier::get("funcName", module.getContext()),
builder.getSymbolRefAttr(func.getName())));
namedAttrs.push_back(std::make_pair(
Identifier::get("funcName", module.getContext()),
SymbolRefAttr::get(builder.getContext(), func.getName())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numInputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumArguments())));

View File

@ -1,7 +1,7 @@
// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: builtin.func @forward(
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list<!torch.float> {
// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]

View File

@ -36,7 +36,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// -----
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
@ -53,7 +53,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// -----
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>

View File

@ -1,6 +1,6 @@
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: builtin.func @forward
// CHECK-LABEL: func @forward
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2

View File

@ -5,7 +5,7 @@
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @positional
func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw []
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw [] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
@ -13,7 +13,7 @@ func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) ->
// -----
// CHECK-LABEL: func @kwValid
func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
@ -21,7 +21,7 @@ func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !ba
// -----
// CHECK-LABEL: func @posArgPack
func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
@ -29,28 +29,28 @@ func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) ->
// -----
// CHECK-LABEL: func @kwArgPack
func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{expected <= kw arg names vs args}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second", "third", "fourth"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{positional arg pack must be the first kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*", "*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{kw arg pack must be the last kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**", "next"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
@ -62,20 +62,20 @@ func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -
// -----
// CHECK-LABEL: module @valid_template
module @valid_template {
builtin.module @valid_template {
// CHECK: basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
// CHECK: func @forInts(%arg0: i32) -> i32
func @forInts(%arg0 : i32) -> i32 {
builtin.func @forInts(%arg0 : i32) -> i32 {
return %arg0 : i32
}
}
}
// -----
module @invalid_template {
builtin.module @invalid_template {
basicpy.func_template @__global$pkg.foobar {
// expected-error @+1 {{illegal operation in func_template}}
module {}
builtin.module {}
}
}

View File

@ -342,7 +342,7 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
return %0 : !torch.int
}
// CHECK-LABEL: builtin.func @torch.prim.TupleUnpack(
// CHECK-LABEL: func @torch.prim.TupleUnpack(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[ARG0]] : !torch.tensor
@ -353,7 +353,7 @@ func @torch.prim.TupleUnpack(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !tor
}
// CHECK-LABEL: builtin.func @torch.aten.__contains__.str(
// CHECK-LABEL: func @torch.aten.__contains__.str(
// CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
// CHECK-SAME: %[[K1:.*]]: !torch.str,
// CHECK-SAME: %[[V1:.*]]: !torch.tensor) -> !torch.bool {
@ -369,7 +369,7 @@ func @torch.aten.__contains__.str(%k0 : !torch.str, %v0: !torch.tensor, %k1: !to
return %pred : !torch.bool
}
// CHECK-LABEL: builtin.func @torch.aten.__contains__.str$with_dict_modified(
// CHECK-LABEL: func @torch.aten.__contains__.str$with_dict_modified(
// CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
// CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor) -> !torch.bool {
// CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
@ -389,7 +389,7 @@ func @torch.aten.__contains__.str$with_dict_modified(%k0 : !torch.str, %v0: !tor
return %pred : !torch.bool
}
// CHECK-LABEL: builtin.func @torch.aten.__getitem__.Dict_str(
// CHECK-LABEL: func @torch.aten.__getitem__.Dict_str(
// CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
// CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[DICT:.*]] = torch.prim.DictConstruct
@ -403,7 +403,7 @@ func @torch.aten.__getitem__.Dict_str(%k0 : !torch.str, %v0: !torch.tensor, %k1:
return %v : !torch.tensor
}
// CHECK-LABEL: builtin.func @torch.aten.add.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.add.int() -> !torch.int {
// CHECK: %[[CST9:.*]] = torch.constant.int 9
// CHECK: return %[[CST9]] : !torch.int
// CHECK: }
@ -414,7 +414,7 @@ func @torch.aten.add.int() -> !torch.int {
return %ret : !torch.int
}
// CHECK-LABEL: builtin.func @torch.aten.sub.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.sub.int() -> !torch.int {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: return %[[CST1]] : !torch.int
// CHECK: }
@ -425,7 +425,7 @@ func @torch.aten.sub.int() -> !torch.int {
return %ret : !torch.int
}
// CHECK-LABEL: builtin.func @torch.aten.mul.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.mul.int() -> !torch.int {
// CHECK: %[[CST30:.*]] = torch.constant.int 30
// CHECK: return %[[CST30]] : !torch.int
// CHECK: }
@ -436,7 +436,7 @@ func @torch.aten.mul.int() -> !torch.int {
return %ret : !torch.int
}
// CHECK-LABEL: builtin.func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK-LABEL: func @torch.aten.mul.int$with_zero() -> !torch.int {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: return %[[CST0]] : !torch.int
// CHECK: }
@ -447,7 +447,7 @@ func @torch.aten.mul.int$with_zero() -> !torch.int {
return %ret : !torch.int
}
// CHECK-LABEL: builtin.func @torch.aten.floordiv.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.floordiv.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int
// CHECK: }
@ -458,7 +458,7 @@ func @torch.aten.floordiv.int() -> !torch.int {
return %ret : !torch.int
}
// CHECK-LABEL: builtin.func @torch.aten.remainder.int() -> !torch.int {
// CHECK-LABEL: func @torch.aten.remainder.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int
// CHECK: }

View File

@ -5,7 +5,7 @@
torch.class_type @c {}
%0 = torch.nn_module {
// expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}}
func @f()
builtin.func @f()
} : !torch.nn.Module<"c">
// -----
@ -33,7 +33,7 @@ torch.class_type @c {
torch.class_type @c {
// expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}}
func @f()
builtin.func @f()
}
// -----
@ -60,7 +60,7 @@ torch.class_type @c {
torch.method "f", @f
}
func @f(%arg0: !torch.nn.Module<"c">) {
builtin.func @f(%arg0: !torch.nn.Module<"c">) {
return
}
@ -71,11 +71,11 @@ torch.class_type @c {
torch.method "f", @f
}
func private @f(%arg0: !torch.nn.Module<"c">)
builtin.func private @f(%arg0: !torch.nn.Module<"c">)
// -----
func private @f() {
builtin.func private @f() {
return
}
torch.class_type @c {
@ -85,7 +85,7 @@ torch.class_type @c {
// -----
func private @f(!torch.nn.Module<"other_c">) {
builtin.func private @f(!torch.nn.Module<"other_c">) {
return
}
torch.class_type @c {
@ -101,21 +101,21 @@ torch.class_type @c {
// -----
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
// -----
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
func @f(%arg0: i32 {torch.type_bound = 1})
builtin.func @f(%arg0: i32 {torch.type_bound = 1})
// -----
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
func @f(%arg0: i32 {torch.type_bound = i32})
builtin.func @f(%arg0: i32 {torch.type_bound = i32})
// -----
func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
builtin.func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<!torch.tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<!torch.tensor> to !torch.tensor
return %0 : !torch.tensor
@ -123,7 +123,7 @@ func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
// -----
func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<!torch.tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<!torch.tensor>
return %0 : !torch.optional<!torch.tensor>
@ -132,11 +132,11 @@ func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.op
// -----
// expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
// -----
func @torch.tensor() {
builtin.func @torch.tensor() {
// Incompatible shape.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
@ -145,7 +145,7 @@ func @torch.tensor() {
// -----
func @torch.tensor() {
builtin.func @torch.tensor() {
// Incompatible dtype.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
@ -154,7 +154,7 @@ func @torch.tensor() {
// -----
func @torch.tensor() {
builtin.func @torch.tensor() {
// Incompatible type.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
@ -163,7 +163,7 @@ func @torch.tensor() {
// -----
func @torch.prim.ListConstruct() {
builtin.func @torch.prim.ListConstruct() {
%int2 = torch.constant.int 2
// expected-error@+1 {{operand types should have the same type as the list contained type}}
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>

View File

@ -2,7 +2,7 @@
// -----
// CHECK-LABEL: builtin.func @prim.if$branch_merge_type_tensor(
// CHECK-LABEL: func @prim.if$branch_merge_type_tensor(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
@ -33,7 +33,7 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// -----
// CHECK-LABEL: builtin.func @prim.if$branch_merge_type_optional(
// CHECK-LABEL: func @prim.if$branch_merge_type_optional(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<!torch.tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<!torch.tensor>) {
@ -60,7 +60,7 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// -----
// CHECK-LABEL: builtin.func @prim.loop$region_arg_to_internal(
// CHECK-LABEL: func @prim.loop$region_arg_to_internal(
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<!torch.tensor> {
// CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[INDV:.*]] = torch.constant.int 0

View File

@ -5,7 +5,7 @@
// CHECK: %[[SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor<[2,3,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[SHAPED]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
return %0 : !torch.vtensor
}
@ -18,7 +18,7 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[CASTED]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[NONVAL_TENSOR]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[ERASED]] : !torch.tensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1 : !torch.tensor
@ -31,7 +31,7 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[SHAPED:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[SHAPED]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
@ -44,7 +44,7 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
@ -58,7 +58,7 @@ func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[INPUT]], %[[WEIGHT]], %[[BIAS]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
@ -69,7 +69,7 @@ func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> !torch.vtensor<[?,?,?,?],unk>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[CONV2D]] : !torch.vtensor<[?,?,?,?],unk> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
builtin.func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
@ -83,7 +83,7 @@ func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!tor
// CHECK: %[[CONV2D:.*]] = torch.aten.conv2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[CONV2D]] : !torch.vtensor<[?,?,?,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
builtin.func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
@ -96,7 +96,7 @@ func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.v
// -----
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
@ -113,7 +113,7 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
// -----
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
@ -128,7 +128,7 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
// CHECK: %[[FLATTENED:.*]] = torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[FLATTENED]] : !torch.tensor<[?],f32> to !torch.tensor
// CHECK: return %[[SHAPE_ERASED]]
func @flatten_all(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
builtin.func @flatten_all(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
%end = torch.constant.int -1
%start = torch.constant.int 0
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[3,2,?,5],f32>, !torch.int, !torch.int -> !torch.tensor
@ -137,7 +137,7 @@ func @flatten_all(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
// CHECK-LABEL: func @flatten_some(
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[3,?,5],f32>
func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
builtin.func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
%end = torch.constant.int -2
%start = torch.constant.int 1
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[3,2,?,5],f32>, !torch.int, !torch.int -> !torch.tensor
@ -146,7 +146,7 @@ func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
// CHECK-LABEL: func @flatten_rank0(
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[1],f32>
func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
builtin.func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%end = torch.constant.int -1
%start = torch.constant.int 0
%0 = torch.aten.flatten.using_ints %arg0, %start, %end : !torch.tensor<[],f32>, !torch.int, !torch.int -> !torch.tensor
@ -157,7 +157,7 @@ func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<[1],f32>
func @torch.aten.unsqueeze$basic(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$basic(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -165,7 +165,7 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.tensor<[],f32>) -> !torch.tensor
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<[1],f32>
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -173,7 +173,7 @@ func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.tensor<[],f32>) -> !torc
// CHECK-LABEL: func @torch.aten.unsqueeze$invalid(
// CHECK: torch.aten.unsqueeze {{.*}} !torch.tensor<*,f32>
func @torch.aten.unsqueeze$invalid(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$invalid(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%int1 = torch.constant.int 1
%0 = torch.aten.unsqueeze %arg0, %int1 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -181,7 +181,7 @@ func @torch.aten.unsqueeze$invalid(%arg0: !torch.tensor<[],f32>) -> !torch.tenso
// CHECK-LABEL: func @torch.aten.unsqueeze$invalid_negative(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<*,f32>
func @torch.aten.unsqueeze$invalid_negative(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$invalid_negative(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%int-2 = torch.constant.int -2
%0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -189,7 +189,7 @@ func @torch.aten.unsqueeze$invalid_negative(%arg0: !torch.tensor<[],f32>) -> !to
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<[1,2,3,4],f32>
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
%int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.tensor<[2,3,4],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -197,7 +197,7 @@ func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.tensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<[2,3,4,1],f32>
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
%int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.tensor<[2,3,4],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -205,7 +205,7 @@ func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.tensor<[2,3,4],f32>) -
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<[2,3,1,4],f32>
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
%int2 = torch.constant.int 2
%0 = torch.aten.unsqueeze %arg0, %int2 : !torch.tensor<[2,3,4],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
@ -213,7 +213,7 @@ func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.tensor<[2,3,4],f32>)
// CHECK-LABEL: func @torch.aten.unsqueeze$unknown_position(
// CHECK: torch.aten.unsqueeze {{.*}} -> !torch.tensor<*,f32>
func @torch.aten.unsqueeze$unknown_position(%arg0: !torch.tensor<[2],f32>, %arg1: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.unsqueeze$unknown_position(%arg0: !torch.tensor<[2],f32>, %arg1: !torch.int) -> !torch.tensor {
%0 = torch.aten.unsqueeze %arg0, %arg1 : !torch.tensor<[2],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
}
@ -221,7 +221,7 @@ func @torch.aten.unsqueeze$unknown_position(%arg0: !torch.tensor<[2],f32>, %arg1
// -----
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[4,6,3],f32>, %arg1: !torch.vtensor<[1,1,3],f32>, %arg2: !torch.vtensor<[?,3],f32>) {
builtin.func @f(%arg0: !torch.vtensor<[4,6,3],f32>, %arg1: !torch.vtensor<[1,1,3],f32>, %arg2: !torch.vtensor<[?,3],f32>) {
%int1 = torch.constant.int 1
// CHECK: torch.aten.add{{.*}} -> !torch.vtensor<[?,?,?],f32>
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,6,3],f32>, !torch.vtensor<[1,1,3],f32>, !torch.int -> !torch.vtensor
@ -233,7 +233,7 @@ func @f(%arg0: !torch.vtensor<[4,6,3],f32>, %arg1: !torch.vtensor<[1,1,3],f32>,
// -----
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// Check propagation through multiple ops.
// CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
// CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
@ -249,7 +249,7 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// Check rewriting logic in case of mixes of users that do/don't allow type
// refinement.
// CHECK-LABEL: func @f
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.vtensor) {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[REFINED_TYPE:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor<[2,3,?],f32>
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[2,3,?],f32> -> !torch.vtensor
// CHECK: %[[ORIGINAL_TYPE:.*]] = torch.tensor_static_info_cast %[[REFINED_TYPE]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
@ -265,7 +265,7 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> (!torch.vtensor, !torch.vtensor)
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
@ -278,13 +278,13 @@ func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f
// CHECK: func private @callee
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<[2,3,?],f32>
func @f() {
module {
func private @callee(%arg0: !torch.vtensor) {
builtin.func @f() {
builtin.module {
builtin.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func @caller(%arg0: !torch.vtensor<[2,3,?],f32>) {
builtin.func @caller(%arg0: !torch.vtensor<[2,3,?],f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return
@ -295,14 +295,14 @@ func @f() {
// -----
// CHECK-LABEL: builtin.func @f(
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.tensor) -> !torch.bool {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[TENSOR]] : !torch.tensor to !torch.optional<!torch.tensor>
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[TENSOR]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func @f(%arg : !torch.tensor) -> !torch.bool {
builtin.func @f(%arg : !torch.tensor) -> !torch.bool {
%none = torch.constant.none
%optional = "torch.derefine"(%arg) : (!torch.tensor) -> !torch.optional<!torch.tensor>
%ret = "torch.aten.__isnot__"(%optional, %none) : (!torch.optional<!torch.tensor>, !torch.none) -> !torch.bool
@ -311,7 +311,7 @@ func @f(%arg : !torch.tensor) -> !torch.bool {
// -----
// CHECK-LABEL: builtin.func @aten.arange.start$int64_dtype(
// CHECK-LABEL: func @aten.arange.start$int64_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.int,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -322,7 +322,7 @@ func @f(%arg : !torch.tensor) -> !torch.bool {
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],si64> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
builtin.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
@ -330,7 +330,7 @@ func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !to
// -----
// CHECK-LABEL: builtin.func @aten.arange.start$float32_dtype(
// CHECK-LABEL: func @aten.arange.start$float32_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.float,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -341,7 +341,7 @@ func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !to
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
builtin.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
@ -349,7 +349,7 @@ func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) ->
// -----
// CHECK-LABEL: builtin.func @aten.arange.start$specified_dtype(
// CHECK-LABEL: func @aten.arange.start$specified_dtype(
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -360,7 +360,7 @@ func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) ->
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
builtin.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
%int6 = torch.constant.int 6
%none = torch.constant.none
%ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
@ -369,7 +369,7 @@ func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
// -----
// CHECK-LABEL: builtin.func @aten.sum.dim_IntList(
// CHECK-LABEL: func @aten.sum.dim_IntList(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -383,7 +383,7 @@ func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[3],si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.sum.dim_IntList(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
builtin.func @aten.sum.dim_IntList(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%int0 = torch.constant.int 0
@ -395,7 +395,7 @@ func @aten.sum.dim_IntList(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
// -----
// CHECK-LABEL: builtin.func @aten.sum.dim_IntList$keepdim(
// CHECK-LABEL: func @aten.sum.dim_IntList$keepdim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -411,7 +411,7 @@ func @aten.sum.dim_IntList(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
// CHECK-SAME: !torch.vtensor<[1,3,1],si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.sum.dim_IntList$keepdim(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
builtin.func @aten.sum.dim_IntList$keepdim(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.vtensor {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
@ -422,7 +422,7 @@ func @aten.sum.dim_IntList$keepdim(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.v
}
// -----
// CHECK-LABEL: builtin.func @aten.sum.dim_IntList$unknown_position(
// CHECK-LABEL: func @aten.sum.dim_IntList$unknown_position(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],si64>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
@ -433,7 +433,7 @@ func @aten.sum.dim_IntList$keepdim(%t: !torch.vtensor<[2,3,?],si64>) -> !torch.v
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[?],si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.sum.dim_IntList$unknown_position(%t: !torch.vtensor<[2,3,?],si64>, %dim0: !torch.int) -> !torch.vtensor {
builtin.func @aten.sum.dim_IntList$unknown_position(%t: !torch.vtensor<[2,3,?],si64>, %dim0: !torch.int) -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%int-1 = torch.constant.int -1
@ -444,7 +444,7 @@ func @aten.sum.dim_IntList$unknown_position(%t: !torch.vtensor<[2,3,?],si64>, %d
// -----
// CHECK-LABEL: builtin.func @aten.any.dim(
// CHECK-LABEL: func @aten.any.dim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -452,7 +452,7 @@ func @aten.sum.dim_IntList$unknown_position(%t: !torch.vtensor<[2,3,?],si64>, %d
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[2,3],i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any.dim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
builtin.func @aten.any.dim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor
@ -461,7 +461,7 @@ func @aten.any.dim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// -----
// CHECK-LABEL: builtin.func @aten.any.dim$keepdim(
// CHECK-LABEL: func @aten.any.dim$keepdim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -469,7 +469,7 @@ func @aten.any.dim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[2,3,1],i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any.dim$keepdim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
builtin.func @aten.any.dim$keepdim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
%true = torch.constant.bool true
%int-1 = torch.constant.int -1
%ret = torch.aten.any.dim %t, %int-1, %true : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor
@ -478,7 +478,7 @@ func @aten.any.dim$keepdim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// -----
// CHECK-LABEL: builtin.func @aten.any.dim$unknown_position(
// CHECK-LABEL: func @aten.any.dim$unknown_position(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
@ -486,7 +486,7 @@ func @aten.any.dim$keepdim(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[?,?,?],i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any.dim$unknown_position(%t: !torch.vtensor<[2,3,?],i1>, %dim: !torch.int) -> !torch.vtensor {
builtin.func @aten.any.dim$unknown_position(%t: !torch.vtensor<[2,3,?],i1>, %dim: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%ret = torch.aten.any.dim %t, %dim, %true : !torch.vtensor<[2,3,?],i1>, !torch.int, !torch.bool -> !torch.vtensor
return %ret : !torch.vtensor
@ -494,20 +494,20 @@ func @aten.any.dim$unknown_position(%t: !torch.vtensor<[2,3,?],i1>, %dim: !torch
// -----
// CHECK-LABEL: builtin.func @aten.any(
// CHECK-LABEL: func @aten.any(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<[2,3,?],i1> -> !torch.vtensor<[1],i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<[1],i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
builtin.func @aten.any(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
%ret = torch.aten.any %t: !torch.vtensor<[2,3,?],i1> -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: builtin.func @aten.transpose.int(
// CHECK-LABEL: func @aten.transpose.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -515,7 +515,7 @@ func @aten.any(%t: !torch.vtensor<[2,3,?],i1>) -> !torch.vtensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,5,4,3],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @aten.transpose.int(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
builtin.func @aten.transpose.int(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%ret = torch.aten.transpose.int %t, %int1, %int-1 : !torch.tensor<[2,3,4,5],si64>, !torch.int, !torch.int -> !torch.tensor
@ -524,7 +524,7 @@ func @aten.transpose.int(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// -----
// CHECK-LABEL: builtin.func @aten.transpose.int$unknown_position(
// CHECK-LABEL: func @aten.transpose.int$unknown_position(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>,
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -532,7 +532,7 @@ func @aten.transpose.int(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?,?,?],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @aten.transpose.int$unknown_position(%t: !torch.tensor<[2,3,4,5],si64>, %dim0: !torch.int) -> !torch.tensor {
builtin.func @aten.transpose.int$unknown_position(%t: !torch.tensor<[2,3,4,5],si64>, %dim0: !torch.int) -> !torch.tensor {
%int-1 = torch.constant.int -1
%ret = torch.aten.transpose.int %t, %dim0, %int-1 : !torch.tensor<[2,3,4,5],si64>, !torch.int, !torch.int -> !torch.tensor
return %ret: !torch.tensor
@ -540,7 +540,7 @@ func @aten.transpose.int$unknown_position(%t: !torch.tensor<[2,3,4,5],si64>, %di
// -----
// CHECK-LABEL: builtin.func @aten.view(
// CHECK-LABEL: func @aten.view(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -551,7 +551,7 @@ func @aten.transpose.int$unknown_position(%t: !torch.tensor<[2,3,4,5],si64>, %di
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @aten.view(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
builtin.func @aten.view(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
%int2 = torch.constant.int 2
%int-1 = torch.constant.int -1
%sizes = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
@ -561,7 +561,7 @@ func @aten.view(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// -----
// CHECK-LABEL: builtin.func @prim.if$refined_type_conflicting(
// CHECK-LABEL: func @prim.if$refined_type_conflicting(
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
@ -574,7 +574,7 @@ func @aten.view(%t: !torch.tensor<[2,3,4,5],si64>) -> !torch.tensor {
// CHECK: }
// CHECK: return %[[PRED:.*]] : !torch.tensor
func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
%optional = torch.derefine %none: !torch.none to !torch.optional<!torch.tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<!torch.tensor>, !torch.none -> !torch.bool
%res = torch.prim.If %pred -> (!torch.tensor) {
@ -589,7 +589,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// ----
// CHECK-LABEL: builtin.func @torch.aten.tensor.float(
// CHECK-LABEL: func @torch.aten.tensor.float(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -597,7 +597,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = "torch.aten.tensor.float"(%t, %none, %none, %false) : (!torch.float, !torch.none, !torch.none, !torch.bool) -> !torch.tensor
@ -606,7 +606,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
// ----
// CHECK-LABEL: builtin.func @torch.aten.tensor.float$specified_dtype(
// CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST11:.*]] = torch.constant.int 11
@ -615,7 +615,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],i1> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%int11 = torch.constant.int 11
%false = torch.constant.bool false
@ -625,7 +625,7 @@ func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor
// ----
// CHECK-LABEL: builtin.func @torch.aten.tensor(
// CHECK-LABEL: func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -635,7 +635,7 @@ func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
builtin.func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<!torch.list<!torch.float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
@ -643,7 +643,7 @@ func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.te
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.tensor$empty_list() -> !torch.tensor {
// CHECK-LABEL: func @torch.aten.tensor$empty_list() -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[DATA:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.float>
@ -651,7 +651,7 @@ func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.te
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor$empty_list() -> !torch.tensor {
builtin.func @torch.aten.tensor$empty_list() -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%data = torch.prim.ListConstruct : () -> !torch.list<!torch.float>
@ -661,7 +661,7 @@ func @torch.aten.tensor$empty_list() -> !torch.tensor {
// ----
// CHECK-LABEL: builtin.func @torch.aten.tensor$specified_dtype(
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4
@ -670,7 +670,7 @@ func @torch.aten.tensor$empty_list() -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
builtin.func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
%none = torch.constant.none
%int4 = torch.constant.int 4
%false = torch.constant.bool false
@ -680,7 +680,7 @@ func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!torch.float
// ----
// CHECK-LABEL: builtin.func @torch.aten.zeros(
// CHECK-LABEL: func @torch.aten.zeros(
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -689,7 +689,7 @@ func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!torch.float
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<[?,2],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%int2 = torch.constant.int 2
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
@ -699,7 +699,7 @@ func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
// ----
// CHECK-LABEL: builtin.func @torch.aten.index_select(
// CHECK-LABEL: func @torch.aten.index_select(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2],si64>) -> !torch.tensor {
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -707,7 +707,7 @@ func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,2,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.tensor<[2], si64>) -> !torch.tensor {
builtin.func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.tensor<[2], si64>) -> !torch.tensor {
%dim = torch.constant.int 1
%ret = torch.aten.index_select %input, %dim, %index : !torch.tensor<[2,3,4], f32>, !torch.int, !torch.tensor<[2], si64> -> !torch.tensor
return %ret : !torch.tensor
@ -715,7 +715,7 @@ func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %index: !torc
// ----
// CHECK-LABEL: builtin.func @torch.aten.index_select$unknown_indexes(
// CHECK-LABEL: func @torch.aten.index_select$unknown_indexes(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[?],si64>) -> !torch.tensor {
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -723,7 +723,7 @@ func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %index: !torc
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.tensor<[?], si64>) -> !torch.tensor {
builtin.func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.tensor<[?], si64>) -> !torch.tensor {
%dim = torch.constant.int 1
%ret = torch.aten.index_select %input, %dim, %index : !torch.tensor<[2,3,4], f32>, !torch.int, !torch.tensor<[?], si64> -> !torch.tensor
return %ret : !torch.tensor
@ -731,7 +731,7 @@ func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3,4], f32
// ----
// CHECK-LABEL: builtin.func @torch.aten.index_select$unknown_dim(
// CHECK-LABEL: func @torch.aten.index_select$unknown_dim(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[?],si64>) -> !torch.tensor {
@ -739,14 +739,14 @@ func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3,4], f32
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4], f32>, %dim: !torch.int, %index: !torch.tensor<[?], si64>) -> !torch.tensor {
builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4], f32>, %dim: !torch.int, %index: !torch.tensor<[?], si64>) -> !torch.tensor {
%ret = torch.aten.index_select %input, %dim, %index : !torch.tensor<[2,3,4], f32>, !torch.int, !torch.tensor<[?], si64> -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.select.int(
// CHECK-LABEL: func @torch.aten.select.int(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[INDEX:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -754,28 +754,28 @@ func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4], f32>, %
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,1,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.int) -> !torch.tensor {
%dim = torch.constant.int 1
%ret = torch.aten.select.int %input, %dim, %index : !torch.tensor<[2,3,4], f32>, !torch.int, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.type_as(
// CHECK-LABEL: func @torch.aten.type_as(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<[?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
builtin.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
return %ret: !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.gather(
// CHECK-LABEL: func @torch.aten.gather(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[1,2,3],si64>) -> !torch.tensor {
@ -784,14 +784,14 @@ func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1,2,3],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torch.int, %index: !torch.tensor<[1,2,3], si64>) -> !torch.tensor {
builtin.func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torch.int, %index: !torch.tensor<[1,2,3], si64>) -> !torch.tensor {
%false = torch.constant.bool false
%ret = torch.aten.gather %input, %dim, %index, %false : !torch.tensor<[2,3,4], f32>, !torch.int, !torch.tensor<[1,2,3], si64>, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.expand(
// CHECK-LABEL: func @torch.aten.expand(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
@ -802,7 +802,7 @@ func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torch.int, %
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,5,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.tensor {
builtin.func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.tensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%int5 = torch.constant.int 5
@ -813,7 +813,7 @@ func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.tensor {
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.expand$unknown_sizes(
// CHECK-LABEL: func @torch.aten.expand$unknown_sizes(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
// CHECK-SAME: %[[SIZEX:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -824,7 +824,7 @@ func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.tensor {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK: }
func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32>, %index: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32>, %index: !torch.int) -> !torch.tensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%int4 = torch.constant.int 4
@ -834,7 +834,7 @@ func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32>, %inde
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.repeat(
// CHECK-LABEL: func @torch.aten.repeat(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
// CHECK-SAME: %[[REPEATX:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -844,7 +844,7 @@ func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32>, %inde
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?,16],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !torch.int) -> !torch.tensor {
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%repeats = torch.prim.ListConstruct %int1, %repeat, %int4: (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
@ -854,7 +854,7 @@ func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !torch.int
// ----
// CHECK-LABEL: builtin.func @torch.aten.cat(
// CHECK-LABEL: func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -863,7 +863,7 @@ func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !torch.int
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,?,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
builtin.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<!torch.tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
@ -871,7 +871,7 @@ func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.cat$unknown_dim(
// CHECK-LABEL: func @torch.aten.cat$unknown_dim(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
@ -880,38 +880,38 @@ func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.cat$unknown_dim(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>, %dim: !torch.int) -> !torch.tensor {
builtin.func @torch.aten.cat$unknown_dim(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>, %dim: !torch.int) -> !torch.tensor {
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<!torch.tensor>
%ret = torch.aten.cat %tensorList, %dim: !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten._shape_as_tensor(
// CHECK-LABEL: func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<[3],si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[3],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK: }
func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
builtin.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<[?],si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[?],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK: }
func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
builtin.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
return %ret : !torch.tensor
}
// ----
// CHECK-LABEL: builtin.func @torch.aten.embedding(
// CHECK-LABEL: func @torch.aten.embedding(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -919,7 +919,7 @@ func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) ->
// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<[2,3,512],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,3,512],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor