From ce4d6d1f8332510e8c0e831a409942d7e6d75377 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Fri, 11 Feb 2022 14:34:05 -0500 Subject: [PATCH] Remove hacky aten.select.int lowering code --- e2e_testing/torchscript/xfail_sets.py | 1 - .../TorchToLinalg/TorchToLinalg.cpp | 15 -- .../Torch/Transforms/DecomposeComplexOps.cpp | 18 +- test/Dialect/Torch/decompose-complex-ops.mlir | 169 ++++++++++-------- 4 files changed, 110 insertions(+), 93 deletions(-) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 540454953..fb19a058d 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -15,7 +15,6 @@ # to the backend contract. COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", - "IouOfModule_basic", "TableBatchEmbeddingModule_basic", "MobilenetV2Module_basic", "MobilenetV3Module_basic", diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 205f67603..9f6c911a9 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -3521,7 +3521,6 @@ public: RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); - int64_t resultRank = resultType.getRank(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); @@ -3592,21 +3591,7 @@ public: Value result = rewriter.create( loc, input, offsets, resultShape, strides); - // TODO: This code is for selectOp, remove once squeeze dim is added - if (resultRank < inputType.getRank()) { - SmallVector reassociation(resultRank); - int64_t resultIdx = 0; - for (auto i : llvm::seq(0, inputType.getRank())) { - if (resultIdx < resultRank) - reassociation[resultIdx].push_back(i); - if (i != dim) - resultIdx++; - } - result = - rewriter.create(loc, result, reassociation); - } rewriter.replaceOpWithNewOp(op, resultType, result); - return success(); } }; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 43d789148..52f47b6bc 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -173,14 +173,22 @@ public: LogicalResult matchAndRewrite(AtenSelectIntOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); + Value start = op.index(); + Value dim = op.dim(); + Value self = op.self(); + Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value end = - rewriter.create(loc, one.getType(), op.index(), one); - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), - op.self(), op.dim(), - op.index(), end, one); + Value startPlusOne = + rewriter.create(loc, one.getType(), start, one); + Value slice = rewriter.create( + loc, computeReductionType(rewriter, op, self, dim, /*keepDim=*/true), + op.self(), dim, start, startPlusOne, /*step=*/one); + // `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after + // slicing, while `aten.select.int` does. + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + slice, op.dim()); return success(); } }; diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 10f57ce11..b5237ce5a 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -1,7 +1,7 @@ // RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s -// CHECK-LABEL: func @matmul_no_decompose -// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor +// CHECK-LABEL: func @matmul_no_decompose +// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor return %0 : !torch.tensor @@ -10,16 +10,16 @@ func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch. // ----- -// CHECK-LABEL: func @matmul_decompose_2d -// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor +// CHECK-LABEL: func @matmul_decompose_2d +// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor return %0 : !torch.tensor } // ----- -// CHECK-LABEL: func @matmul_decompose_3d( -// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor +// CHECK-LABEL: func @matmul_decompose_3d( +// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor return %0 : !torch.tensor @@ -31,10 +31,10 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> { // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[KEEP_DIM0:.*]] = torch.constant.bool true -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] : +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] : // CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],si64> // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>, +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>, // CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.float -> !torch.tensor<[2,3],f32> // CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32> // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list @@ -58,10 +58,10 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[TRU:.*]] = torch.constant.bool true -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> // CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.tensor<[2,1],si64> // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>, +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>, // CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32> // CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32> // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list @@ -88,7 +88,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten // CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool -> // CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.tensor<[?,1],si64> // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>, +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>, // CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.float -> !torch.tensor<[?,?],f32> // CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32> // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list @@ -112,7 +112,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[TRU:.*]] = torch.constant.bool true -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool // CHECK-SAME: -> !torch.tensor<*,f32>, !torch.tensor<*,si64> // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<*,f32>, !torch.tensor<*,f32>, @@ -181,12 +181,13 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> { } // ----- -// CHECK-LABEL: func @torch.aten.argmax( -// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64> -// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64> +// CHECK-LABEL: func @torch.aten.argmax( +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : +// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64> +// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64> func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { %int0 = torch.constant.int 0 %true = torch.constant.bool true @@ -195,15 +196,17 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,? } // ----- -// CHECK-LABEL: func @torch.aten.argmax$reduceall( -// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64> -// CHECK: return %[[IND]] : !torch.vtensor<[],si64> +// CHECK-LABEL: func @torch.aten.argmax$reduceall( +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : +// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : +// CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64> +// CHECK: return %[[IND]] : !torch.vtensor<[],si64> func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> { %none = torch.constant.none %false = torch.constant.bool false @@ -214,7 +217,8 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // ----- // CHECK-LABEL: func @torch.aten.square( // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32> func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> @@ -312,12 +316,12 @@ func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso } // ----- -// CHECK-LABEL: func @torch.aten._unsafe_view$static -// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>) -// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct -// CHECK-NOT: torch.aten._unsafe_view -// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] -// CHECK-NEXT: return +// CHECK-LABEL: func @torch.aten._unsafe_view$static +// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>) +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct +// CHECK-NOT: torch.aten._unsafe_view +// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] +// CHECK-NEXT: return func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> { %c1 = torch.constant.int 1 %c2 = torch.constant.int 2 @@ -329,12 +333,12 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> ! } // ----- -// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic -// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct -// CHECK-NOT: torch.aten._unsafe_view -// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] -// CHECK-NEXT: return +// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic +// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct +// CHECK-NOT: torch.aten._unsafe_view +// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] +// CHECK-NEXT: return func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> { %c256 = torch.constant.int 512 %c32 = torch.constant.int 32 @@ -344,24 +348,26 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to } // ----- -// CHECK-LABEL: func @_log.softmax( -// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64> -// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list -// CHECK: %[[TRU:.*]] = torch.constant.bool true -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32> -// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32> -// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32> -func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> { +// CHECK-LABEL: func @torch.aten._log_softmax( +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64> +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list +// CHECK: %[[TRU:.*]] = torch.constant.bool true +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32> +// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32> +// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32> +func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> { %int0 = torch.constant.int 0 %false = torch.constant.bool false %0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32> @@ -369,23 +375,42 @@ func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vt } // ----- -// CHECK-LABEL: func @bernoulli +// CHECK-LABEL: func @torch.aten.bernoulli // CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01 -// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 -// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE0:.*]] = torch.constant.none -// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1> -// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE0]] : !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func @bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01 +// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00 +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE0:.*]] = torch.constant.none +// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1> +// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE0]] : +// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor +// CHECK: return %[[CAST]] : !torch.vtensor +func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor { %none = torch.constant.none %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32> %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor return %1 : !torch.vtensor } + +// ----- +// CHECK-LABEL: func @torch.aten.select.int( +// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> { +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[END:.*]] = torch.aten.add.int %[[CST0]], %[[CST1]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[T]], %[[CST0]], %[[CST0]], %[[END]], %[[CST1]] : +// CHECK-SAME: !torch.vtensor<[?,?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64> +// CHECK: %[[SELECT:.*]] = torch.aten.squeeze.dim %[[SLICE]], %[[CST0]] : +// CHECK-SAME: !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[?],si64> +// CHECK: return %[[SELECT]] : !torch.vtensor<[?],si64> +func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> { + %int0 = torch.constant.int 0 + %0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64> + return %0 : !torch.vtensor<[?],si64> +}