diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index bd6d324bd..363f23699 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5292,6 +5292,32 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ }]; } +def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_IntType:$dim, + Torch_IntType:$index + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenSelectScatterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -5747,6 +5773,34 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ }]; } +def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$src, + Torch_IntType:$dim, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, + Torch_IntType:$step + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenSliceScatterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 7897dc194..8f2f33ad3 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -7,6 +7,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" @@ -29,6 +33,94 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +static Value toPositiveValidDim(ConversionPatternRewriter &rewriter, + Location loc, Value torchType, + Value builtinType, Value valueForNone, + Value dimSize) { + if (torchType.getType().isa()) + return valueForNone; + auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); + Value positiveDim = + toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt); + // startOrEnd < 0 ? 
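
For reference, the two ops registered above carry the semantics of PyTorch's torch.select_scatter and torch.slice_scatter: each returns a copy of self with src embedded at the given index or slice, which is why they are modeled with HasValueSemantics. A minimal sketch of the expected behavior, assuming a PyTorch build that ships both operators:

import torch

# select_scatter: embed a rank-reduced src at position `index` along `dim`.
base = torch.zeros(2, 3)
row = torch.tensor([1.0, 2.0, 3.0])            # rank of base minus 1
out = torch.select_scatter(base, row, dim=0, index=1)
# out[1] == row; every other element is copied from base.

# slice_scatter: embed src over the strided slice base[start:end:step] along `dim`.
base = torch.zeros(8)
out = torch.slice_scatter(base, torch.ones(2), dim=0, start=2, end=6, step=2)
# Elements 2 and 4 become 1.0; the rest keep their original values.
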
0 : startOrEnd + Value cst0 = rewriter.create( + loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); + Value predDimSltZero = rewriter.create( + loc, arith::CmpIPredicate::slt, positiveDim, cst0); + Value startOrEndAtLeastZero = + rewriter.create(loc, predDimSltZero, cst0, positiveDim); + // startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd + Value startOrEndSgtDimSize = rewriter.create( + loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt); + Value startOrEndBoundedByDimSize = rewriter.create( + loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero); + + return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize); +} + +template +LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVector &resultShape, + SmallVector &offsets, + SmallVector &strides) { + Location loc = op.getLoc(); + auto input = adaptor.self(); + RankedTensorType inputType = + input.getType().template cast(); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return op->emitError("unimplemented: dim is not constant"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value dimSize = inputShape[dim]; + + Value torchTypeStart = op.start(); + Value torchTypeEnd = op.end(); + Value builtinTypeStart = adaptor.start(); + Value builtinTypeEnd = adaptor.end(); + + if (torchTypeStart.getType().isa() || + torchTypeEnd.getType().isa()) + return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); + + int64_t step; + if (!matchPattern(op.step(), m_TorchConstantInt(&step))) { + if (!op.step().getType().template isa()) + return op->emitError("unimplemented: step is not constant"); + step = 1; + } + + Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, + builtinTypeStart, zero, dimSize); + Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, + dimSize, dimSize); + + // end >= start ? 
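
The arith/select sequence emitted by toPositiveValidDim above reduces to a wrap-then-clamp of a slice bound. A plain-Python rendering of that arithmetic (the function name is illustrative, not part of the patch):

def clamp_slice_bound(bound, dim_size):
    # toPositiveDimDynamic: wrap a negative index once around the dimension.
    if bound < 0:
        bound += dim_size
    # startOrEnd < 0 ? 0 : startOrEnd
    bound = max(bound, 0)
    # startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
    return min(bound, dim_size)

assert clamp_slice_bound(-2, 5) == 3   # wraps to 3
assert clamp_slice_bound(-9, 5) == 0   # clamped at the low end
assert clamp_slice_bound(7, 5) == 5    # clamped at the high end
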
end : start + Value endSgeStart = rewriter.create( + loc, arith::CmpIPredicate::sge, end, start); + end = rewriter.create(loc, endSgeStart, end, start); + Value stepIndex = rewriter.create(loc, step); + + // Slice logic: resultSize = floordiv(end - start + step - 1, step) + resultShape = getTensorSizes(rewriter, loc, input); + Value len = rewriter.create(loc, end, start); + Value resultSize = rewriter.create(loc, len, stepIndex); + resultSize = rewriter.create(loc, resultSize, one); + resultSize = rewriter.create(loc, resultSize, stepIndex); + resultShape[dim] = resultSize; + + strides.resize(inputType.getRank(), one); + offsets.resize(inputType.getRank(), zero); + + offsets[dim] = start; + strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + return success(); +} namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -742,77 +834,19 @@ public: TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.self(); - RankedTensorType inputType = input.getType().cast(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - int64_t dim; - if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return op->emitError("unimplemented: dim is not constant"); - - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - Value dimSize = inputShape[dim]; - - auto adjustStartOrEnd = [&](Value startOrEndTorchType, - Value startOrEndBuiltin, Value valueForNone) { - if (startOrEndTorchType.getType().isa()) - return valueForNone; - auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); - Value startOrEndToPositive = - toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt); - // startOrEnd < 0 ? 0 : startOrEnd - Value cst0 = rewriter.create( - loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); - Value predDimSltZero = rewriter.create( - loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0); - Value startOrEndAtLeastZero = rewriter.create( - loc, predDimSltZero, cst0, startOrEndToPositive); - // startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd - Value startOrEndSgtDimSize = rewriter.create( - loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt); - Value startOrEndBoundedByDimSize = rewriter.create( - loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero); - - return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize); - }; - - if (op.start().getType().isa() || - op.end().getType().isa()) - return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); - Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero); - Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize); - - // end >= start ? 
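
The result-extent computation shared by the helper above and the old inline lowering is a ceiling division of the clamped range by the step, after forcing end >= start. Worked out in plain Python (name illustrative):

def slice_result_size(start, end, step):
    end = max(end, start)                    # end >= start ? end : start
    return (end - start + step - 1) // step  # floordiv(end - start + step - 1, step)

assert slice_result_size(2, 6, 2) == 2   # elements at indices 2 and 4
assert slice_result_size(0, 5, 3) == 2   # elements at indices 0 and 3
assert slice_result_size(4, 1, 1) == 0   # empty slice once end is raised to start
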
end : start - Value endSgeStart = rewriter.create( - loc, arith::CmpIPredicate::sge, end, start); - end = rewriter.create(loc, endSgeStart, end, start); - - int64_t step; - if (!matchPattern(op.step(), m_TorchConstantInt(&step))) { - if (!op.step().getType().isa()) - return op->emitError("unimplemented: step is not constant"); - step = 1; + SmallVector resultShape; + SmallVector offsets; + SmallVector strides; + if (failed(prepareArgumentsForSlicingOp( + op, adaptor, rewriter, resultShape, offsets, strides))) { + return failure(); } - // Slice logic: resultSize = floordiv(end - start + step - 1, step) - Value stepIndex = rewriter.create(loc, step); - Value len = rewriter.create(loc, end, start); - Value resultSize = rewriter.create(loc, len, stepIndex); - resultSize = rewriter.create(loc, resultSize, one); - resultSize = - rewriter.create(loc, resultSize, stepIndex); - - SmallVector resultShape = getTensorSizes(rewriter, loc, input); - resultShape[dim] = resultSize; - - SmallVector offsets(inputType.getRank(), zero); - SmallVector strides(inputType.getRank(), one); - offsets[dim] = start; - strides[dim] = rewriter.create(loc, strides[dim], stepIndex); - Value result = rewriter.create( loc, input, offsets, resultShape, strides); @@ -1019,6 +1053,55 @@ public: }; } // namespace +namespace { +class ConvertAtenSliceScatterOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + TypeConverter *typeConverter = getTypeConverter(); + + auto input = adaptor.self(); + + RankedTensorType resultType = + typeConverter->convertType(op->getResult(0).getType()) + .cast(); + + SmallVector resultShape; + SmallVector offsets; + SmallVector strides; + if (failed(prepareArgumentsForSlicingOp( + op, adaptor, rewriter, resultShape, offsets, strides))) { + return failure(); + } + + Value src = adaptor.src(); + auto srcType = src.getType().cast(); + int64_t srcRank = srcType.getRank(); + SmallVector srcAbstractSizes(srcRank, kUnknownSize); + auto abstractSrcType = + RankedTensorType::get(srcAbstractSizes, srcType.getElementType()); + Value abstractSrc = + rewriter.create(loc, abstractSrcType, src); + + Value result = rewriter.create( + loc, abstractSrc, input, offsets, resultShape, strides); + + rewriter.replaceOpWithNewOp(op, resultType, result); + + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1047,4 +1130,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9ea6363db..94d8aa494 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -16,7 +16,9 @@ #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" +#include using 
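
ConvertAtenSliceScatterOp above lowers aten.slice_scatter to tensor.insert_slice: src, cast to a fully dynamic shape, is inserted into the input at the offsets, sizes, and strides computed by prepareArgumentsForSlicingOp. A reference model of what that pattern computes, assuming PyTorch's slice_scatter is available for comparison (helper name illustrative):

import torch

def slice_scatter_ref(base, src, dim, start, end, step):
    # Mirrors the insert_slice lowering: copy base, then overwrite
    # the strided slice along `dim` with src.
    out = base.clone()
    index = [slice(None)] * base.dim()
    index[dim] = slice(start, end, step)
    out[tuple(index)] = src
    return out

base, src = torch.zeros(8), torch.ones(2)
assert torch.equal(slice_scatter_ref(base, src, 0, 2, 6, 2),
                   torch.slice_scatter(base, src, 0, 2, 6, 2))
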
namespace mlir; using namespace mlir::torch; @@ -2120,6 +2122,55 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose the `aten.select_scatter` operation into `aten.slice_scatter` op. +class DecomposeAtenSelectScatterOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSelectScatterOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value start = op.index(); + Value dim = op.dim(); + Value self = op.self(); + Value src = op.src(); + + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value startPlusOne = + rewriter.create(loc, one.getType(), start, one); + BaseTensorType srcTensorType = src.getType().cast(); + SmallVector sizes; + if (!srcTensorType.hasSizes()) + return rewriter.notifyMatchFailure(op, "src tensor must have size"); + + ArrayRef srcShape = srcTensorType.getSizes(); + // `src` has a reduced rank. Hence add 1. + int64_t srcRank = srcShape.size() + 1; + int64_t dimInt = 0; + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, srcRank); + if (!isValidDim(dimInt, srcRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + + sizes.append(srcShape.begin(), srcShape.end()); + sizes.insert(sizes.begin() + dimInt, 1); + + } else { + sizes.resize(srcShape.size() + 1, kUnknownSize); + } + Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes), + srcTensorType.getDtype()); + src = rewriter.create(loc, srcType, src, dim); + rewriter.replaceOpWithNewOp( + op, op.self().getType(), self, src, dim, start, startPlusOne, + /*step=*/one); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -2271,6 +2322,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 58c7b5579..c5e956cf0 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -645,10 +645,11 @@ ChangeResult TypeAnalyzer::visitOperation( AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, - AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp, - AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, - AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, - AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp, + AtenSelectScatterOp, AtenSliceTensorOp, AtenSliceScatterOp, + AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, + AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, + AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, + ValsemVariantAtenCopyOp, AtenZeroFunctionalOp, AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) { return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 637871d96..cb9cbf6b8 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ 
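
The decomposition above rewrites aten.select_scatter as an unsqueeze of src at dim followed by aten.slice_scatter over the half-open range [index, index + 1) with step 1. Its net effect, sketched with the corresponding PyTorch ops for the non-negative-index case (function name illustrative):

import torch

def select_scatter_via_slice_scatter(base, src, dim, index):
    # unsqueeze restores the dimension that select_scatter's src lacks,
    # then slice_scatter writes it back over a width-1 slice.
    return torch.slice_scatter(base, src.unsqueeze(dim), dim, index, index + 1, 1)

base = torch.zeros(2, 3)
row = torch.tensor([1.0, 2.0, 3.0])
assert torch.equal(select_scatter_via_slice_scatter(base, row, 0, 1),
                   torch.select_scatter(base, row, 0, 1))
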
b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -53,14 +53,14 @@ module { return %0 : !torch.list } func.func @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { @@ -124,9 +124,9 @@ module { return %0 : !torch.list } func.func @__torch__.torch.jit._shape_functions.arange_end(%arg0: !torch.union, %arg1: !torch.any, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.operator "aten.ge"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool torch.prim.If %0 -> () { torch.prim.If.yield @@ -140,9 +140,9 @@ module { return %3 : !torch.list } func.func @__torch__.torch.jit._shape_functions.arange_start(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.operator "aten.ge"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool torch.prim.If %0 -> () { torch.prim.If.yield @@ -164,9 +164,9 @@ module { return %5 : !torch.list } func.func @__torch__.torch.jit._shape_functions.arange_start_step(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any, %arg6: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.operator "aten.ne"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool torch.prim.If %0 -> () { torch.prim.If.yield @@ -201,8 +201,8 @@ module { return %5 : !torch.list } func.func @__torch__.torch.jit._shape_functions.squeeze_nodim(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true %int1 = torch.constant.int 1 + %true = torch.constant.bool true %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int torch.prim.Loop %1, %true, init() { @@ -221,10 +221,10 @@ module { return %0 : !torch.list } func.func @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { + %true = torch.constant.bool true %none = torch.constant.none %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 - %true = torch.constant.bool true %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -282,11 +282,11 @@ module { return %0 : !torch.list } func.func 
@__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.aten.le.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool %1 = torch.prim.If %0 -> (!torch.int) { torch.prim.If %arg2 -> () { @@ -373,8 +373,8 @@ module { return %11 : !torch.list } func.func @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list { - %int9223372036854775807 = torch.constant.int 9223372036854775807 %true = torch.constant.bool true + %int9223372036854775807 = torch.constant.int 9223372036854775807 %none = torch.constant.none %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 @@ -507,11 +507,11 @@ module { return %int9223372036854775807 : !torch.int } func.func @__torch__.torch.jit._shape_functions.select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %int1 = torch.constant.int 1 %true = torch.constant.bool true %none = torch.constant.none %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { @@ -581,10 +581,10 @@ module { return %16 : !torch.list } func.func @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " %true = torch.constant.bool true %int1 = torch.constant.int 1 + %none = torch.constant.none + %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool @@ -663,8 +663,8 @@ module { return %16 : !torch.list } func.func @__torch__.torch.jit._shape_functions.multiply_integers(%arg0: !torch.list) -> !torch.int { - %true = torch.constant.bool true %int1 = torch.constant.int 1 + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.prim.Loop %0, %true, init(%int1) { ^bb0(%arg1: !torch.int, %arg2: !torch.int): @@ -676,11 +676,11 @@ module { } func.func @__torch__.torch.jit._shape_functions.embedding(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list { %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { @@ -762,19 +762,19 @@ module { return %4 : !torch.list } func.func @__torch__.torch.jit._shape_functions.mm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "AssertionError: " - %str_0 = 
torch.constant.str "AssertionError: mat2 must be a matrix" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: self must be a matrix" - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %str = torch.constant.str "AssertionError: self must be a matrix" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: mat2 must be a matrix" + %str_1 = torch.constant.str "AssertionError: " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -791,7 +791,7 @@ module { torch.prim.If %6 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %7 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -800,11 +800,11 @@ module { return %9 : !torch.list } func.func @__torch__.torch.jit._shape_functions.dot(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -833,12 +833,12 @@ module { return %6 : !torch.list } func.func @__torch__.torch.jit._shape_functions.mv(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %false = torch.constant.bool false + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -872,15 +872,15 @@ module { %str_0 = torch.constant.str "AssertionError: mat2 must be a matrix" %str_1 = torch.constant.str "AssertionError: self must be a matrix" %str_2 = torch.constant.str "AssertionError: " - %none = torch.constant.none - %str_3 = torch.constant.str "AssertionError: both arguments to matmul need to be at least 1D" - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %int-2 = torch.constant.int -2 - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %int-2 = torch.constant.int -2 + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 + %str_3 = torch.constant.str "AssertionError: both arguments to matmul need to be at least 1D" + %none = torch.constant.none %0 = torch.prim.Uninitialized : 
!torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -1204,13 +1204,13 @@ module { return %5 : !torch.list } func.func @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %false = torch.constant.bool false - %true = torch.constant.bool true - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %true = torch.constant.bool true + %false = torch.constant.bool false + %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" + %str_0 = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int @@ -1251,8 +1251,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %16 -> () { - %20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str + %20 = torch.aten.format(%str, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %21 = torch.aten.add.str %str_0, %20 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %21, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -1270,25 +1270,25 @@ module { return %3 : !torch.list } func.func @__torch__.torch.jit._shape_functions.linear(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>) -> !torch.list { + %none = torch.constant.none %str = torch.constant.str "AssertionError: both arguments to matmul need to be at least 1D" %int-1 = torch.constant.int -1 %true = torch.constant.bool true %int-2 = torch.constant.int -2 %false = torch.constant.bool false - %str_0 = torch.constant.str "AssertionError: self must be a matrix" - %str_1 = torch.constant.str "AssertionError: mat2 must be a matrix" - %str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_3 = torch.constant.str "AssertionError: " - %none = torch.constant.none %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %str_0 = torch.constant.str "AssertionError: " + %str_1 = torch.constant.str "AssertionError: self must be a matrix" + %str_2 = torch.constant.str "AssertionError: mat2 must be a matrix" + %str_3 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -1334,7 +1334,7 @@ module { torch.prim.If %15 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %16 
= torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -1343,7 +1343,7 @@ module { torch.prim.If %18 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } torch.prim.If.yield %5 : !torch.list @@ -1368,7 +1368,7 @@ module { torch.prim.If %18 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int @@ -1377,7 +1377,7 @@ module { torch.prim.If %21 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %22 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -1413,7 +1413,7 @@ module { torch.prim.If %27 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %28 = torch.prim.ListConstruct : () -> !torch.list @@ -1430,7 +1430,7 @@ module { torch.prim.If %31 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %32 = torch.aten.len.t %4 : !torch.list -> !torch.int @@ -1438,7 +1438,7 @@ module { torch.prim.If %33 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %34 = torch.aten.__getitem__.t %28, %int1 : !torch.list, !torch.int -> !torch.int @@ -1447,7 +1447,7 @@ module { torch.prim.If %36 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %37 = torch.aten.__getitem__.t %28, %int0 : !torch.list, !torch.int -> !torch.int @@ -1490,7 +1490,7 @@ module { torch.prim.If %23 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %24 = torch.aten.len.t %4 : !torch.list -> !torch.int @@ -1498,7 +1498,7 @@ module { torch.prim.If %25 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %26 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int @@ -1507,7 +1507,7 @@ module { torch.prim.If %28 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -1587,8 +1587,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %50 -> () { - %54 = torch.aten.format(%str_2, %45, %47, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %55 = torch.aten.add.str %str_3, %54 : !torch.str, !torch.str -> !torch.str + %54 = torch.aten.format(%str_3, %45, %47, 
%arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %55 = torch.aten.add.str %str_0, %54 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %55, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -1673,8 +1673,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %31 -> () { - %35 = torch.aten.format(%str_2, %26, %28, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %36 = torch.aten.add.str %str_3, %35 : !torch.str, !torch.str -> !torch.str + %35 = torch.aten.format(%str_3, %26, %28, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %36 = torch.aten.add.str %str_0, %35 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %36, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -1693,7 +1693,7 @@ module { torch.prim.If %18 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } torch.prim.If.yield @@ -1703,11 +1703,11 @@ module { return %11 : !torch.list } func.func @__torch__.torch.jit._shape_functions.t(%arg0: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { @@ -1738,24 +1738,24 @@ module { return %4 : !torch.list } func.func @__torch__.torch.jit._shape_functions.max_pool2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list { - %false = torch.constant.bool false - %str = torch.constant.str "AssertionError: stride should not be zeero" - %int-1 = torch.constant.int -1 - %int-2 = torch.constant.int -2 - %int-3 = torch.constant.int -3 - %int-4 = torch.constant.int -4 - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" - %str_2 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_3 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %none = torch.constant.none - %str_4 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" %true = torch.constant.bool true + %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %false = torch.constant.bool false + %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 + %str_0 = torch.constant.str "AssertionError: stride should not be zeero" + %str_1 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + %str_2 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + %str_3 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" + %str_4 = 
torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" + %int-4 = torch.constant.int -4 + %int-3 = torch.constant.int -3 + %int-2 = torch.constant.int -2 + %int-1 = torch.constant.int -1 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -1768,7 +1768,7 @@ module { torch.prim.If %2 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -1799,7 +1799,7 @@ module { torch.prim.If %10 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int @@ -1837,7 +1837,7 @@ module { torch.prim.If %19 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none torch.prim.If.yield } %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int @@ -1861,7 +1861,7 @@ module { torch.prim.If %26 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none torch.prim.If.yield } %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int @@ -1885,7 +1885,7 @@ module { torch.prim.If %33 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -1903,7 +1903,7 @@ module { torch.prim.If %40 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %41 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int @@ -1938,7 +1938,7 @@ module { torch.prim.If %52 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int @@ -1980,7 +1980,7 @@ module { torch.prim.If %66 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool @@ -1993,7 +1993,7 @@ module { torch.prim.If %68 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool @@ -2006,7 +2006,7 @@ module { torch.prim.If %70 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %71 = torch.aten.__getitem__.t %arg0, %int1 : 
!torch.list, !torch.int -> !torch.int @@ -2052,7 +2052,7 @@ module { torch.prim.If %77 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int @@ -2067,7 +2067,7 @@ module { torch.prim.If %80 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> !torch.bool @@ -2080,7 +2080,7 @@ module { torch.prim.If %82 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -2095,9 +2095,9 @@ module { return %85 : !torch.list } func.func @__torch__.torch.jit._shape_functions.pooling_output_shape(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.bool) -> !torch.int { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: stride should not be zeero" %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: stride should not be zeero" + %none = torch.constant.none %0 = torch.aten.ne.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool torch.prim.If %0 -> () { torch.prim.If.yield @@ -2109,8 +2109,8 @@ module { return %1 : !torch.int } func.func @__torch__.torch.jit._shape_functions.pooling_output_shape_pad_lr(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int { - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %0 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int %1 = torch.aten.add.int %0, %arg3 : !torch.int, !torch.int -> !torch.int %2 = torch.aten.sub.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int @@ -2148,15 +2148,15 @@ module { return %0 : !torch.int } func.func @__torch__.torch.jit._shape_functions.pool2d_shape_check(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.int) -> !torch.none { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -2274,24 +2274,24 @@ module { return %none : !torch.none } func.func @__torch__.torch.jit._shape_functions.max_pool2d_with_indices(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: 
!torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> { - %false = torch.constant.bool false - %str = torch.constant.str "AssertionError: stride should not be zeero" - %int4 = torch.constant.int 4 - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %str_0 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %str_2 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_3 = torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" - %str_4 = torch.constant.str "AssertionError: " - %int-4 = torch.constant.int -4 - %int-3 = torch.constant.int -3 - %int-2 = torch.constant.int -2 %int-1 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %int-3 = torch.constant.int -3 + %int-4 = torch.constant.int -4 + %str = torch.constant.str "AssertionError: " + %str_0 = torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" + %str_1 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" + %str_2 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + %none = torch.constant.none + %str_3 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %str_4 = torch.constant.str "AssertionError: stride should not be zeero" + %false = torch.constant.bool false %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -2304,7 +2304,7 @@ module { torch.prim.If %2 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none torch.prim.If.yield } %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -2335,7 +2335,7 @@ module { torch.prim.If %10 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int @@ -2373,7 +2373,7 @@ module { torch.prim.If %19 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int @@ -2397,7 +2397,7 @@ module { torch.prim.If %26 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int @@ -2421,7 +2421,7 @@ module { torch.prim.If 
%33 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -2439,7 +2439,7 @@ module { torch.prim.If %40 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none torch.prim.If.yield } %41 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int @@ -2474,7 +2474,7 @@ module { torch.prim.If %52 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none torch.prim.If.yield } %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int @@ -2516,7 +2516,7 @@ module { torch.prim.If %66 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool @@ -2529,7 +2529,7 @@ module { torch.prim.If %68 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool @@ -2542,7 +2542,7 @@ module { torch.prim.If %70 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %71 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int @@ -2588,7 +2588,7 @@ module { torch.prim.If %77 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int @@ -2603,7 +2603,7 @@ module { torch.prim.If %80 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> !torch.bool @@ -2616,7 +2616,7 @@ module { torch.prim.If %82 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -2632,11 +2632,11 @@ module { return %86 : !torch.tuple, list> } func.func @__torch__.torch.jit._shape_functions.transpose(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { + %true = torch.constant.bool true %none = torch.constant.none %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 - %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.int) { @@ -2740,9 +2740,9 @@ module { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 - %false = torch.constant.bool false - %none = torch.constant.none %str = 
torch.constant.str "AssertionError: " + %none = torch.constant.none + %false = torch.constant.bool false %int3 = torch.constant.int 3 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool @@ -2920,10 +2920,10 @@ module { return %28 : !torch.list } func.func @__torch__.torch.jit._shape_functions.conv_output_size(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %true = torch.constant.bool true %0 = call @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.none %1 = torch.aten.len.t %arg5 : !torch.list -> !torch.int %2 = torch.aten.gt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool @@ -2964,13 +2964,13 @@ module { return %4 : !torch.list } func.func @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.none { - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true + %false = torch.constant.bool false %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = call @__torch__.torch.jit._shape_functions.check_non_negative(%arg4) : (!torch.list) -> !torch.bool @@ -3073,9 +3073,9 @@ module { return %none : !torch.none } func.func @__torch__.torch.jit._shape_functions.check_non_negative(%arg0: !torch.list) -> !torch.bool { - %true = torch.constant.bool true - %false = torch.constant.bool false %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.prim.Loop %0, %true, init(%false) { ^bb0(%arg1: !torch.int, %arg2: !torch.bool): @@ -3095,9 +3095,9 @@ module { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 - %false = torch.constant.bool false - %none = torch.constant.none %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %false = torch.constant.bool false %int4 = torch.constant.int 4 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool @@ -3291,9 +3291,9 @@ module { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 - %false = torch.constant.bool false - %none = torch.constant.none %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %false = torch.constant.bool false %int5 = torch.constant.int 5 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool @@ -3471,8 +3471,8 @@ module { return %28 : !torch.list } func.func 
@__torch__.torch.jit._shape_functions.conv_backwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>) -> !torch.tuple, list, list> { - %int1 = torch.constant.int 1 %true = torch.constant.bool true + %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int torch.prim.Loop %1, %true, init() { @@ -3495,9 +3495,9 @@ module { return %6 : !torch.tuple, list, list> } func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { + %true = torch.constant.bool true %none = torch.constant.none %str = torch.constant.str "AssertionError: " - %true = torch.constant.bool true %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -3622,11 +3622,11 @@ module { %str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension" %str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions" %false = torch.constant.bool false - %int1 = torch.constant.int 1 %true = torch.constant.bool true %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: " + %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %str_1 = torch.constant.str "AssertionError: " %0 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int torch.prim.Loop %0, %true, init() { ^bb0(%arg2: !torch.int): @@ -3824,10 +3824,10 @@ module { return %12 : !torch.list } func.func @__torch__.torch.jit._shape_functions.check_cat_no_zero_dim(%arg0: !torch.list>) -> !torch.none { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %true = torch.constant.bool true %int0 = torch.constant.int 0 + %true = torch.constant.bool true + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int torch.prim.Loop %0, %true, init() { ^bb0(%arg1: !torch.int): @@ -3845,11 +3845,11 @@ module { return %none : !torch.none } func.func @__torch__.torch.jit._shape_functions.legacy_cat_wrap_dim(%arg0: !torch.int, %arg1: !torch.list>) -> !torch.int { - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %none = torch.constant.none + %true = torch.constant.bool true + %false = torch.constant.bool false %0 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int %1 = torch.derefine %none : !torch.none to !torch.optional %2 = torch.prim.Loop %0, %true, init(%1) { @@ -3892,9 +3892,9 @@ module { return %4 : !torch.int } func.func @__torch__.torch.jit._shape_functions.should_skip(%arg0: !torch.list) -> !torch.bool { - %false = torch.constant.bool false - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false %0 = call @__torch__.torch.jit._shape_functions.numel(%arg0) : (!torch.list) -> !torch.int %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -3907,8 +3907,8 @@ module { return %2 : !torch.bool } func.func @__torch__.torch.jit._shape_functions.numel(%arg0: !torch.list) -> !torch.int { - %true = torch.constant.bool true %int1 = torch.constant.int 1 + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.prim.Loop %0, %true, init(%int1) { ^bb0(%arg1: 
!torch.int, %arg2: !torch.int): @@ -3919,19 +3919,19 @@ module { return %1 : !torch.int } func.func @__torch__.torch.jit._shape_functions.check_cat_shape_except_dim(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.none { - %str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension" - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %none = torch.constant.none - %str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions" %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: Tensors must have same number of dimensions" + %none = torch.constant.none + %int1 = torch.constant.int 1 + %true = torch.constant.bool true + %str_0 = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool torch.prim.If %2 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %3 = torch.aten.__range_length %int0, %0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int @@ -3946,7 +3946,7 @@ module { torch.prim.If %8 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } torch.prim.If.yield @@ -3959,10 +3959,10 @@ module { } func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %int0 = torch.constant.int 0 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " %int1 = torch.constant.int 1 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool @@ -4039,10 +4039,10 @@ module { %str_0 = torch.constant.str "AssertionError: invalid shape dimensions" %str_1 = torch.constant.str "AssertionError: only one dimension can be inferred" %int-1 = torch.constant.int -1 - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 %true = torch.constant.bool true + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.prim.Loop %0, %true, init(%int1) { ^bb0(%arg2: !torch.int, %arg3: !torch.int): @@ -4131,15 +4131,15 @@ module { return %9 : !torch.list } func.func @__torch__.torch.jit._shape_functions.infer_size_impl(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %str = torch.constant.str "AssertionError: invalid shape" - %false = torch.constant.bool false - %str_0 = torch.constant.str "AssertionError: invalid shape dimensions" - %str_1 = torch.constant.str "AssertionError: only one dimension can be inferred" - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %none = torch.constant.none + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 + %str = torch.constant.str 
"AssertionError: only one dimension can be inferred" + %str_0 = torch.constant.str "AssertionError: invalid shape dimensions" + %false = torch.constant.bool false + %str_1 = torch.constant.str "AssertionError: invalid shape" %0 = torch.prim.Uninitialized : !torch.int %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.derefine %none : !torch.none to !torch.optional @@ -4150,7 +4150,7 @@ module { %11:2 = torch.prim.If %10 -> (!torch.int, !torch.optional) { %12 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool torch.prim.If %12 -> () { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } else { torch.prim.If.yield @@ -4196,7 +4196,7 @@ module { } %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool torch.prim.If %6 -> () { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } else { torch.prim.If.yield @@ -4214,12 +4214,12 @@ module { return %7 : !torch.list } func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool @@ -4292,12 +4292,12 @@ module { return %6 : !torch.list } func.func @__torch__.torch.jit._shape_functions.expand_one_unused(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list { - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %str = torch.constant.str "AssertionError: " - %none = torch.constant.none - %true = torch.constant.bool true %int-1 = torch.constant.int -1 + %true = torch.constant.bool true + %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool @@ -4373,9 +4373,9 @@ module { %none = torch.constant.none %str = torch.constant.str "AssertionError: " %int0 = torch.constant.int 0 - %false = torch.constant.bool false - %true = torch.constant.bool true %int1 = torch.constant.int 1 + %true = torch.constant.bool true + %false = torch.constant.bool false %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int torch.prim.Loop %1, %true, init() { @@ -4511,22 +4511,22 @@ module { return %3 : !torch.tuple, list> } func.func @__torch__.torch.jit._shape_functions.addmm(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list { - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" + %int2 = torch.constant.int 2 + %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %str_0 = torch.constant.str "The size of tensor a {} must match the size of 
tensor b ({}) at non-singleton dimension {}" %false = torch.constant.bool false %true = torch.constant.bool true - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %str_0 = torch.constant.str "AssertionError: self must be a matrix" - %none = torch.constant.none + %int0 = torch.constant.int 0 %str_1 = torch.constant.str "AssertionError: mat2 must be a matrix" - %str_2 = torch.constant.str "AssertionError: " + %str_2 = torch.constant.str "AssertionError: self must be a matrix" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int @@ -4543,7 +4543,7 @@ module { torch.prim.If %6 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -4587,8 +4587,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %24 -> () { - %28 = torch.aten.format(%str, %19, %21, %arg5) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %29 = torch.aten.add.str %str_2, %28 : !torch.str, !torch.str -> !torch.str + %28 = torch.aten.format(%str_0, %19, %21, %arg5) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %29 = torch.aten.add.str %str, %28 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %29, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -4606,14 +4606,14 @@ module { return %12 : !torch.list } func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.optional> { - %str = torch.constant.str "AssertionError: Either output_size or scale_factors must be presented" - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "AssertionError: Must specify exactly one of output_size and scale_factors" - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %none = torch.constant.none + %str = torch.constant.str "AssertionError: Must specify exactly one of output_size and scale_factors" + %str_0 = torch.constant.str "AssertionError: " + %str_1 = torch.constant.str "AssertionError: Either output_size or scale_factors must be presented" %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int %2 = torch.aten.append.t %0, %1 : !torch.list, !torch.int -> !torch.list @@ -4626,7 +4626,7 @@ module { torch.prim.If %8 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %9 = torch.aten.len.t %7 : !torch.list -> !torch.int @@ -4651,7 +4651,7 @@ module { torch.prim.If %10 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none 
torch.prim.If.yield } %11 = torch.aten.len.t %9 : !torch.list -> !torch.int @@ -4675,7 +4675,7 @@ module { %23 = torch.derefine %0 : !torch.list to !torch.optional> torch.prim.If.yield %23 : !torch.optional> } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none %9 = torch.derefine %none : !torch.none to !torch.optional> torch.prim.If.yield %9 : !torch.optional> } @@ -4753,9 +4753,9 @@ module { return %1 : !torch.list } func.func @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %int9223372036854775807 = torch.constant.int 9223372036854775807 %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int %2 = torch.prim.ListConstruct : () -> !torch.list @@ -4783,20 +4783,20 @@ module { return %2 : !torch.list } func.func @__torch__.torch.jit._shape_functions.bmm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: bmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %str = torch.constant.str "AssertionError: bmm only supports 3D tensors" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" + %str_1 = torch.constant.str "AssertionError: mismatching contracting dimension" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -4804,7 +4804,7 @@ module { torch.prim.If %3 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -4822,7 +4822,7 @@ module { torch.prim.If %9 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -4838,10 +4838,10 @@ module { } func.func @__torch__.torch.jit._shape_functions.topk(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.tuple, list> { %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "k ({}) is too big for dimension {} of size {}" %int0 = torch.constant.int 0 + %str = torch.constant.str "k ({}) is too big for dimension {} of size {}" + %str_0 
= torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.list) { @@ -4854,8 +4854,8 @@ module { torch.prim.If.yield } else { %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.format(%str_0, %arg1, %arg2, %9) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %11 = torch.aten.add.str %str, %10 : !torch.str, !torch.str -> !torch.str + %10 = torch.aten.format(%str, %arg1, %arg2, %9) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %11 = torch.aten.add.str %str_0, %10 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %11, %none : !torch.str, !torch.none torch.prim.If.yield } @@ -4874,14 +4874,14 @@ module { return %3 : !torch.tuple, list> } func.func @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int) -> !torch.tuple, list> { - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool @@ -4968,10 +4968,10 @@ module { } func.func @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list, list> { %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -5031,13 +5031,13 @@ module { return %3 : !torch.tuple, list, list> } func.func @__torch__.torch.jit._shape_functions.broadcast_three(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %false = torch.constant.bool false - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_0 = torch.constant.str "AssertionError: " %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" + %false = torch.constant.bool false + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int @@ -5078,8 +5078,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %20 -> () { - %24 = 
torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str + %24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %25, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -5134,8 +5134,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %20 -> () { - %24 = torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str + %24 = torch.aten.format(%str_0, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %25 = torch.aten.add.str %str, %24 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %25, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -5153,13 +5153,13 @@ module { return %7 : !torch.list } func.func @__torch__.torch.jit._shape_functions.broadcast_one_three(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %false = torch.constant.bool false - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_0 = torch.constant.str "AssertionError: " %none = torch.constant.none + %str = torch.constant.str "AssertionError: " + %str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" + %false = torch.constant.bool false + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg2 : !torch.list -> !torch.int %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int @@ -5200,8 +5200,8 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %16 -> () { - %20 = torch.aten.format(%str, %11, %13, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %21 = torch.aten.add.str %str_0, %20 : !torch.str, !torch.str -> !torch.str + %20 = torch.aten.format(%str_0, %11, %13, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %21, %none : !torch.str, !torch.none torch.prim.If.yield } else { @@ -5219,19 +5219,19 @@ module { return %3 : !torch.list } func.func @__torch__.torch.jit._shape_functions.broadcast_inplace(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %false = torch.constant.bool false %true = torch.constant.bool true - %none = torch.constant.none - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) " - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) " + %str_0 = torch.constant.str "AssertionError: " + %none = torch.constant.none + %false = torch.constant.bool false + %str_1 = torch.constant.str "The size 
of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.aten.gt.int %1, %0 : !torch.int, !torch.int -> !torch.bool torch.prim.If %2 -> () { - %5 = torch.aten.format(%str_1, %1, %0) : !torch.str, !torch.int, !torch.int -> !torch.str + %5 = torch.aten.format(%str, %1, %0) : !torch.str, !torch.int, !torch.int -> !torch.str %6 = torch.aten.add.str %str_0, %5 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %6, %none : !torch.str, !torch.none torch.prim.If.yield @@ -5258,7 +5258,7 @@ module { torch.prim.If.yield %false : !torch.bool } torch.prim.If %11 -> () { - %12 = torch.aten.format(%str, %7, %9, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %12 = torch.aten.format(%str_1, %7, %9, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str %13 = torch.aten.add.str %str_0, %12 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %13, %none : !torch.str, !torch.none torch.prim.If.yield @@ -5284,8 +5284,8 @@ module { return %1 : !torch.list } func.func @__torch__.torch.jit._shape_functions.nonzero_upper_bound(%arg0: !torch.list) -> !torch.list { - %int1 = torch.constant.int 1 %true = torch.constant.bool true + %int1 = torch.constant.int 1 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.prim.Loop %0, %true, init(%int1) { ^bb0(%arg1: !torch.int, %arg2: !torch.int): @@ -5570,9 +5570,9 @@ module { return %1 : !torch.list } func.func @__torch__._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %int9223372036854775807 = torch.constant.int 9223372036854775807 %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %true = torch.constant.bool true %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int %2 = torch.prim.ListConstruct : () -> !torch.list @@ -5627,14 +5627,14 @@ module { return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.t"(%arg0: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list, !torch.int, !torch.int) -> !torch.list return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true %int0 = torch.constant.int 0 + %true = torch.constant.bool true %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int torch.prim.Loop %1, %true, init() { @@ -5660,20 +5660,20 @@ module { return %2 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.bmm"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: bmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %str = torch.constant.str 
"AssertionError: bmm only supports 3D tensors" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" + %str_1 = torch.constant.str "AssertionError: mismatching contracting dimension" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -5681,7 +5681,7 @@ module { torch.prim.If %3 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -5699,7 +5699,7 @@ module { torch.prim.If %9 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int @@ -5709,20 +5709,20 @@ module { return %13 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.baddbmm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: baddbmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %str = torch.constant.str "AssertionError: baddbmm only supports 3D tensors" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" + %str_1 = torch.constant.str "AssertionError: mismatching contracting dimension" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int @@ -5730,7 +5730,7 @@ module { torch.prim.If %3 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -5748,7 +5748,7 @@ module { torch.prim.If %9 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -5850,21 +5850,21 @@ module { return %0 : !torch.list } func.func @__torch__.avg_pool2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list { - 
%int-1 = torch.constant.int -1 - %int-2 = torch.constant.int -2 - %int-3 = torch.constant.int -3 - %int-4 = torch.constant.int -4 - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "AssertionError: avg_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_1 = torch.constant.str "AssertionError: avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %none = torch.constant.none - %str_2 = torch.constant.str "AssertionError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints" - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %true = torch.constant.bool true + %str = torch.constant.str "AssertionError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + %str_1 = torch.constant.str "AssertionError: avg_pool2d: padding must be either be a single int, or a tuple of two ints" + %str_2 = torch.constant.str "AssertionError: " + %int-4 = torch.constant.int -4 + %int-3 = torch.constant.int -3 + %int-2 = torch.constant.int -2 + %int-1 = torch.constant.int -1 %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool %2 = torch.prim.If %1 -> (!torch.bool) { @@ -5877,7 +5877,7 @@ module { torch.prim.If %2 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int @@ -5908,7 +5908,7 @@ module { torch.prim.If %10 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int @@ -5946,7 +5946,7 @@ module { torch.prim.If %19 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none torch.prim.If.yield } %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int @@ -5970,7 +5970,7 @@ module { torch.prim.If %26 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none torch.prim.If.yield } %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int @@ -6246,17 +6246,17 @@ module { return %1 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.topk"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> { + %str = torch.constant.str "k ({}) is too big for dimension {} of size {}" + %str_0 = torch.constant.str "AssertionError: " %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "k ({}) is too big for dimension {} of size {}" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int %1 = torch.aten.le.int %arg1, 
%0 : !torch.int, !torch.int -> !torch.bool torch.prim.If %1 -> () { torch.prim.If.yield } else { %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str + %5 = torch.aten.format(%str, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str + %6 = torch.aten.add.str %str_0, %5 : !torch.str, !torch.str -> !torch.str torch.prim.RaiseException %6, %none : !torch.str, !torch.none torch.prim.If.yield } @@ -6286,10 +6286,16 @@ module { %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.slice_scatter"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.list { + return %arg0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.select.int"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.select_scatter"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { + return %arg0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.index_select"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list return %0 : !torch.list @@ -6303,14 +6309,14 @@ module { return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> { - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true + %int-1 = torch.constant.int -1 %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool @@ -6401,11 +6407,11 @@ module { return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %str = torch.constant.str "AssertionError: " + %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.prim.ListConstruct : () -> !torch.list %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ 
-6434,8 +6440,8 @@ module { return %7 : !torch.tuple, list, list> } func.func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> { - %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 %0 = torch.prim.If %arg5 -> (!torch.tuple, list, list>) { %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list @@ -6456,20 +6462,20 @@ module { return %0 : !torch.list } func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension" - %none = torch.constant.none - %str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values" - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %str = torch.constant.str "AssertionError: Must have paired low-high pad amount values" + %none = torch.constant.none + %str_0 = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension" + %true = torch.constant.bool true %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int %2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool torch.prim.If %2 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none + torch.prim.RaiseException %str, %none : !torch.str, !torch.none torch.prim.If.yield } %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -6479,7 +6485,7 @@ module { torch.prim.If %6 -> () { torch.prim.If.yield } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none + torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none torch.prim.If.yield } %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int @@ -6636,8 +6642,8 @@ module { return %none : !torch.none } func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { - %true = torch.constant.bool true %none = torch.constant.none + %true = torch.constant.bool true %0 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool %1 = torch.prim.If %0 -> (!torch.list) { %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index bab82e133..e2f0a333a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -904,9 +904,15 @@ def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) +def aten〇slice_scatter(self: List[int], src: List[int], dim: int = 0, start: 
Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: + return self + def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]: return upstream_shape_functions.select(self, dim, index) +def aten〇select_scatter(self: List[int], src: List[int], dim: int, index: int) -> List[int]: + return self + def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]: return upstream_shape_functions.index_select(self, dim, index) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 2c43cd900..b7b1fa770 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -437,6 +437,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") + emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::sum : (Tensor, int?) -> (Tensor)") @@ -455,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)") + emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index 7e894abac..3a56826ff 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -232,3 +232,112 @@ def SelectIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(10, (5,5))) # ============================================================================== + +# For the aten.slice_scatter op, the arguments are: slice_scatter(input, src, dim=0, start=None, end=None, step=1). +# For the aten.select_scatter op, the arguments are: select_scatter(input, src, dim, index).
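+#
+# A minimal eager-mode sketch of the semantics (illustrative values only; it
+# uses the same torch.ops.aten entry points the tests below exercise):
+#
+#   x = torch.zeros(6, 8); src = torch.ones(6, 1)
+#   y = torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1)
+#   # y is a copy of x whose column 0 is replaced by src.
+#
+#   x = torch.zeros(6, 8, 5); src = torch.ones(8, 5)
+#   z = torch.ops.aten.select_scatter(x, src, dim=0, index=0)
+#   # z is a copy of x whose slice x[0] is replaced by src.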
+class SliceScatterModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1) + +@register_test_case(module_factory=lambda: SliceScatterModule()) +def SliceScatterModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(6, 1)) + +class SliceScatterZeroDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim=0, start=0, end=1, step=1) + + +@register_test_case(module_factory=lambda: SliceScatterZeroDimModule()) +def SliceScatterZeroDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(1, 8)) + +class SliceScatterStepVariationModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=2) + + +@register_test_case(module_factory=lambda: SliceScatterStepVariationModule()) +def SliceScatterStepVariationModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(6, 1)) + +class SliceScatterStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 8], torch.float32, True), + ([6, 1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1) + + +@register_test_case(module_factory=lambda: SliceScatterStaticModule()) +def SliceScatterStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(6, 1)) + +class SelectScatterModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.select_scatter(x, src, dim=0, index=0) + + +@register_test_case(module_factory=lambda: SelectScatterModule()) +def SelectScatterModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8, 5), tu.rand(8, 5)) + +class SelectScatterStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 8, 5], torch.float32, True), + ([6, 5], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.select_scatter(x, src, dim=1, index=0) + + +@register_test_case(module_factory=lambda: SelectScatterStaticModule()) +def SelectScatterStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8, 5), tu.rand(6, 5)) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index c97a31f4f..1a531d0f7 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -1087,3 +1087,21 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int %2 = torch.aten.repeat %arg0, %1 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?],f32> return %2 : !torch.vtensor<[?,?,?],f32> } + +// ----- +// CHECK-LABEL: func 
@torch.aten.select_scatter +// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0 +// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1 +// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1 +// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]] +// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]] +// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]] +// CHECK-NEXT: return %[[SLICE_SCATTER]] +// CHECK-NEXT: } +func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.select_scatter %arg0, %arg1, %int1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +}
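
The CHECK lines above spell out the decomposition: aten.select_scatter(self, src, dim, index) is rewritten as aten.slice_scatter over the length-one slice [index, index + 1) with step 1, after unsqueezing src along dim. A minimal eager-mode sketch of that equivalence (assuming a PyTorch build where both ops are callable through torch.ops.aten, as in the e2e tests above; the shapes are illustrative):

import torch

self_t = torch.rand(6, 8)
src = torch.rand(6)
dim, index = 1, 0

# The op itself.
direct = torch.ops.aten.select_scatter(self_t, src, dim, index)

# The decomposed form: unsqueeze src along `dim`, then scatter it as the
# [index, index + 1) slice with step 1.
decomposed = torch.ops.aten.slice_scatter(
    self_t, src.unsqueeze(dim), dim, index, index + 1, 1)

assert torch.equal(direct, decomposed)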