From 4aad5ccf39d7d1e7ddcb8584acac50202e52e91e Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Wed, 23 Nov 2022 15:02:41 +0800
Subject: [PATCH] fix #1626 return type mismatch (#1634)

---
 e2e_testing/xfail_sets.py                  |  12 +-
 lib/Conversion/TorchToMhlo/Reduction.cpp   |  38 ++--
 test/Conversion/TorchToMhlo/reduction.mlir | 239 ---------------------
 3 files changed, 30 insertions(+), 259 deletions(-)
 delete mode 100644 test/Conversion/TorchToMhlo/reduction.mlir

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 6281ff9dc..c874a7820 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -115,10 +115,18 @@ MHLO_PASS_SET = {
     "MatmulSingleDynamicBatchDim_basic",
     "Matmul_3d",
     "Matmul_4d",
+    "MeanDimEmptyDimModule_basic",
     "MeanDtypeModule_basic",
+    "MeanDynamicSizesModule_basic",
+    "MeanLargeInputModule_basic",
+    "MeanModule_basic",
     "MmTanhModule_basic",
     "Mv_basic",
+    "PrimsConvertElementTypeModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
+    "ReduceSumDimIntListElementTypeBoolModule_basic",
+    "ReduceSumElementTypeBoolModule_basic",
+    "ReduceSumDimIntListEmptyDimModule_basic",
     "ReduceSumDimIntListDtypeFloatModule_basic",
     "ReduceSumDimIntListDtypeIntModule_basic",
     "ReduceSumDimIntListKeepDimFloatModule_basic",
@@ -136,7 +144,6 @@ MHLO_PASS_SET = {
     "LiftFreshCopyModule_basic",
     "Mlp2LayerModuleNoBias_basic",
     "NumelModule_basic",
-    "ReduceSumDimIntListEmptyDimModule_basic",
     "SqueezeModule_allUnitDim",
     "SqueezeDimModule_unitDim",
     "ViewCollapseOnesMiddleModule_basic",
@@ -149,9 +156,6 @@ MHLO_PASS_SET = {
     "ViewTwoToThreeStaticModule_basic",
     "ViewExpandOnesMiddleOppModule_basic",
     "ViewOffsetBackwardTestStaticModule_basic",
-    "MeanModule_basic",
-    "MeanDynamicSizesModule_basic",
-    "MeanDimEmptyDimModule_basic",
     "NumToTensorFloatModule_basic",
     "AtenToDeviceModule_basic",
     "AvgPool2dStaticModule_basic",
diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp
index 589c97ca4..1f4214aef 100644
--- a/lib/Conversion/TorchToMhlo/Reduction.cpp
+++ b/lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -347,15 +347,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.self();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
   if (!inputTy) {
     return rewriter.notifyMatchFailure(op,
                                        "only Tensor types supported in MHLO");
   }
-  auto dtype = adaptor.dtype();
-  if (!dtype.getType().isa<Torch::NoneType>()) {
-    auto dstElemTy = getTypeConverter()
-                         ->convertType(op.getType())
-                         .template dyn_cast<RankedTensorType>()
-                         .getElementType();
+  if (inputTy.getElementType() != outTy.getElementType()) {
+    // Use output element type as computation type.
+    auto dstElemTy = outTy.getElementType();
     input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
     inputTy = input.getType().dyn_cast<RankedTensorType>();
   }
@@ -376,11 +376,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
   for (int64_t i = 0; i < inputTy.getRank(); i++) {
     dims.push_back(i);
   }
-
   Value initValue =
       createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
   auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -401,7 +401,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
     rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
   }
 
-  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
+                                              mhloReduceOp.getResults());
   return success();
 }
 } // namespace
@@ -438,6 +439,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
   Value initValue =
       createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
   auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -458,7 +460,9 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
     rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
   }
 
-  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()),
+      mhloReduceOp.getResults());
   return success();
 }
 } // namespace
@@ -471,15 +475,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.self();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
   if (!inputTy) {
     return rewriter.notifyMatchFailure(op,
                                        "only Tensor types supported in MHLO");
   }
-  auto dtype = adaptor.dtype();
-  if (!dtype.getType().isa<Torch::NoneType>()) {
-    auto dstElemTy = getTypeConverter()
-                         ->convertType(op.getType())
-                         .template dyn_cast<RankedTensorType>()
-                         .getElementType();
+  if (inputTy.getElementType() != outTy.getElementType()) {
+    // Use output element type as computation type.
+    auto dstElemTy = outTy.getElementType();
     input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
     inputTy = input.getType().dyn_cast<RankedTensorType>();
   }
@@ -522,6 +526,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
   if (!initValue)
     return failure();
+  llvm::sort(dims.begin(), dims.end());
   auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
       op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
@@ -566,7 +571,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
         mhloReduceOp.getResult(0), outShapeTensor);
     return success();
   }
-  rewriter.replaceOp(op, mhloReduceOp.getResults());
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
+                                              mhloReduceOp.getResults());
   return success();
 }
 } // namespace
diff --git a/test/Conversion/TorchToMhlo/reduction.mlir b/test/Conversion/TorchToMhlo/reduction.mlir
deleted file mode 100644
index 47dc0a476..000000000
--- a/test/Conversion/TorchToMhlo/reduction.mlir
+++ /dev/null
@@ -1,239 +0,0 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL:  func.func @torch.aten.max.dim$keepdim(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[TRUE:.*]] = torch.constant.bool true
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
-// CHECK:         %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
-// CHECK:         %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK:         %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK:         %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
-// CHECK:         %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
-// CHECK:         %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
-// CHECK:          reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
-// CHECK:           %[[T11:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T12:.*]] = mhlo.select %[[T11]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
-// CHECK:           %[[T13:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T14:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
-// CHECK:           %[[T15:.*]] = mhlo.select %[[T11]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
-// CHECK:           %[[T16:.*]] = mhlo.select %[[T13]], %[[T14]], %[[T15]] : tensor<i1>, tensor<i64>
-// CHECK:           mhlo.return %[[T12]], %[[T16]] : tensor<f32>, tensor<i64>
-// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK:         %[[FROM_ELEMENTS_1:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]] : tensor<2xi64>
-// CHECK:         %[[T7:.*]] = mhlo.dynamic_reshape %[[T6]]#0, %[[FROM_ELEMENTS_1]] : (tensor<?xf32>, tensor<2xi64>) -> tensor<?x1xf32>
-// CHECK:         %[[T8:.*]] = mhlo.dynamic_reshape %[[T6]]#1, %[[FROM_ELEMENTS_1]] : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
-// CHECK:         %[[T9:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x1xf32> -> !torch.vtensor<[?,1],f32>
-// CHECK:         %[[T10:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
-// CHECK:         return %[[T9]], %[[T10]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
-func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
-  %true = torch.constant.bool true
-  %int1 = torch.constant.int 1
-  %values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
-  return %values, %indices : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.max.dim(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[FALSE:.*]] = torch.constant.bool false
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
-// CHECK:         %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
-// CHECK:         %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK:         %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK:         %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
-// CHECK:         %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
-// CHECK:         %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
-// CHECK:          reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
-// CHECK:           %[[T9:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T10:.*]] = mhlo.select %[[T9]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
-// CHECK:           %[[T11:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T12:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
-// CHECK:           %[[T13:.*]] = mhlo.select %[[T9]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
-// CHECK:           %[[T14:.*]] = mhlo.select %[[T11]], %[[T12]], %[[T13]] : tensor<i1>, tensor<i64>
-// CHECK:           mhlo.return %[[T10]], %[[T14]] : tensor<f32>, tensor<i64>
-// CHECK:         %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]]#0 : tensor<?xf32> -> !torch.vtensor<[?],f32>
-// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T6]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
-// CHECK:         return %[[T7]], %[[T8]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
-func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %values, %indices = torch.aten.max.dim %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
-  return %values, %indices : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.argmax$keepdim(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[TRUE:.*]] = torch.constant.bool true
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
-// CHECK:         %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
-// CHECK:         %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK:         %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK:         %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
-// CHECK:         %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
-// CHECK:         %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
-// CHECK:          reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
-// CHECK:           %[[T9:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T10:.*]] = mhlo.select %[[T9]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
-// CHECK:           %[[T11:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T12:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
-// CHECK:           %[[T13:.*]] = mhlo.select %[[T9]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
-// CHECK:           %[[T14:.*]] = mhlo.select %[[T11]], %[[T12]], %[[T13]] : tensor<i1>, tensor<i64>
-// CHECK:           mhlo.return %[[T10]], %[[T14]] : tensor<f32>, tensor<i64>
-// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK:         %[[FROM_ELEMENTS_1:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]] : tensor<2xi64>
-// CHECK:         %[[T7:.*]] = mhlo.dynamic_reshape %[[T6]]#1, %[[FROM_ELEMENTS_1]] : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
-// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
-// CHECK:         return %[[T8]] : !torch.vtensor<[?,1],si64>
-func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
-  %int1 = torch.constant.int 1
-  %true = torch.constant.bool true
-  %indices = torch.aten.argmax %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],si64>
-  return %indices : !torch.vtensor<[?,1],si64>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.argmax(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[FALSE:.*]] = torch.constant.bool false
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
-// CHECK:         %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
-// CHECK:         %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK:         %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK:         %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
-// CHECK:         %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
-// CHECK:         %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
-// CHECK:          reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
-// CHECK:           %[[T8:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T9:.*]] = mhlo.select %[[T8]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
-// CHECK:           %[[T10:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK:           %[[T11:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
-// CHECK:           %[[T12:.*]] = mhlo.select %[[T8]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
-// CHECK:           %[[T13:.*]] = mhlo.select %[[T10]], %[[T11]], %[[T12]] : tensor<i1>, tensor<i64>
-// CHECK:           mhlo.return %[[T9]], %[[T13]] : tensor<f32>, tensor<i64>
-// CHECK:         %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
-// CHECK:         return %[[T7]] : !torch.vtensor<[?],si64>
-func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
-  %int1 = torch.constant.int 1
-  %false = torch.constant.bool false
-  %indices = torch.aten.argmax %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],si64>
-  return %indices : !torch.vtensor<[?],si64>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.sum.dim_Intlist$keepdim(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[INT0:.*]] = torch.constant.int 0
-// CHECK:         %[[TRUE:.*]] = torch.constant.bool true
-// CHECK:         %[[NONE:.*]] = torch.constant.none
-// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK:         %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK:         %[[T3:.*]] = mhlo.reduce(%[[T0]] init: %[[T2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK:         %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK:         %[[T5:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK:         %[[C2:.*]] = arith.constant 2 : index
-// CHECK:         %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK:         %[[T6:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK:         %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[C1_I64]], %[[T6]] : tensor<3xi64>
-// CHECK:         %[[T7:.*]] = mhlo.dynamic_reshape %[[T3]], %[[FROM_ELEMENTS]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<1x1x?xf32>
-// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<1x1x?xf32> -> !torch.vtensor<[1,1,?],f32>
-// CHECK:         return %[[T8]] : !torch.vtensor<[1,1,?],f32>
-func.func @torch.aten.sum.dim_Intlist$keepdim(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %true = torch.constant.bool true
-  %none = torch.constant.none
-  %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.sum.dim_IntList %arg0, %0, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,?],f32>
-  return %1 : !torch.vtensor<[1,1,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.sum.dim_Intlist(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK:         %[[INT1:.*]] = torch.constant.int 1
-// CHECK:         %[[INT0:.*]] = torch.constant.int 0
-// CHECK:         %[[FALSE:.*]] = torch.constant.bool false
-// CHECK:         %[[NONE:.*]] = torch.constant.none
-// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK:         %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK:         %[[T3:.*]] = mhlo.reduce(%[[T0]] init: %[[T2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
-// CHECK:         %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
-// CHECK:         return %[[T4]] : !torch.vtensor<[?],f32>
-func.func @torch.aten.sum.dim_Intlist(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?],f32>
-  return %1 : !torch.vtensor<[?],f32>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.sum(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK:         %[[NONE:.*]] = torch.constant.none
-// CHECK:         %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK:         %[[T2:.*]] = mhlo.reduce(%[[T0]] init: %[[T1]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
-// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
-// CHECK:         return %[[T3]] : !torch.vtensor<[],f32>
-func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
-  %none = torch.constant.none
-  %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
-  return %0 : !torch.vtensor<[],f32>
-}
-
-// -----
-
-// CHECK-LABEL:  func.func @torch.aten.max(
-// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
-// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK:         %[[T1:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK:         %[[T2:.*]] = mhlo.reduce(%[[T0]] init: %[[T1]]) applies mhlo.maximum across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
-// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
-// CHECK:         return %[[T3]] : !torch.vtensor<[],f32>
-func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
-  %0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
-  return %0 : !torch.vtensor<[],f32>
-}
-