Remove hacky aten.select.int lowering code

pull/588/head
Yi Zhang 2022-02-11 14:34:05 -05:00
parent 756b75fb2d
commit ce4d6d1f83
4 changed files with 110 additions and 93 deletions

View File

@ -15,7 +15,6 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"IouOfModule_basic",
"TableBatchEmbeddingModule_basic",
"MobilenetV2Module_basic",
"MobilenetV3Module_basic",

View File

@ -3521,7 +3521,6 @@ public:
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
int64_t resultRank = resultType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@ -3592,21 +3591,7 @@ public:
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
// TODO: This code is for selectOp, remove once squeeze dim is added
if (resultRank < inputType.getRank()) {
SmallVector<ReassociationIndices> reassociation(resultRank);
int64_t resultIdx = 0;
for (auto i : llvm::seq<int64_t>(0, inputType.getRank())) {
if (resultIdx < resultRank)
reassociation[resultIdx].push_back(i);
if (i != dim)
resultIdx++;
}
result =
rewriter.create<tensor::CollapseShapeOp>(loc, result, reassociation);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};

View File

@ -173,14 +173,22 @@ public:
LogicalResult matchAndRewrite(AtenSelectIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value start = op.index();
Value dim = op.dim();
Value self = op.self();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value end =
rewriter.create<AtenAddIntOp>(loc, one.getType(), op.index(), one);
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(op, op.getResult().getType(),
op.self(), op.dim(),
op.index(), end, one);
Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, computeReductionType(rewriter, op, self, dim, /*keepDim=*/true),
op.self(), dim, start, startPlusOne, /*step=*/one);
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
// slicing, while `aten.select.int` does.
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
slice, op.dim());
return success();
}
};

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
// CHECK-LABEL: func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
@ -10,16 +10,16 @@ func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.
// -----
// CHECK-LABEL: func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
// CHECK-LABEL: func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
// CHECK-LABEL: func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
@ -31,10 +31,10 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[KEEP_DIM0:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] :
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] :
// CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
@ -58,10 +58,10 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool ->
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool ->
// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.tensor<[2,1],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
@ -88,7 +88,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool ->
// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.tensor<[?,1],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>,
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>,
// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.float -> !torch.tensor<[?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
@ -112,7 +112,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool
// CHECK-SAME: -> !torch.tensor<*,f32>, !torch.tensor<*,si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<*,f32>, !torch.tensor<*,f32>,
@ -181,12 +181,13 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
}
// -----
// CHECK-LABEL: func @torch.aten.argmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
// CHECK-LABEL: func @torch.aten.argmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] :
// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
%int0 = torch.constant.int 0
%true = torch.constant.bool true
@ -195,15 +196,17 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
}
// -----
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] :
// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
// CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none
%false = torch.constant.bool false
@ -214,7 +217,8 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// -----
// CHECK-LABEL: func @torch.aten.square(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
@ -312,12 +316,12 @@ func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
// CHECK-LABEL: func @torch.aten._unsafe_view$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
%c1 = torch.constant.int 1
%c2 = torch.constant.int 2
@ -329,12 +333,12 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
%c256 = torch.constant.int 512
%c32 = torch.constant.int 32
@ -344,24 +348,26 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
}
// -----
// CHECK-LABEL: func @_log.softmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32>
// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32>
func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
// CHECK-LABEL: func @torch.aten._log_softmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32>
// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32>
@ -369,23 +375,42 @@ func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vt
}
// -----
// CHECK-LABEL: func @bernoulli
// CHECK-LABEL: func @torch.aten.bernoulli
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE0:.*]] = torch.constant.none
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE0]] : !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE0:.*]] = torch.constant.none
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1>
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE0]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.aten.select.int(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[END:.*]] = torch.aten.add.int %[[CST0]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[T]], %[[CST0]], %[[CST0]], %[[END]], %[[CST1]] :
// CHECK-SAME: !torch.vtensor<[?,?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
// CHECK: %[[SELECT:.*]] = torch.aten.squeeze.dim %[[SLICE]], %[[CST0]] :
// CHECK-SAME: !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: return %[[SELECT]] : !torch.vtensor<[?],si64>
func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
%int0 = torch.constant.int 0
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
return %0 : !torch.vtensor<[?],si64>
}