mirror of https://github.com/llvm/torch-mlir
fix #1626 return type mismatch
parent
68f568b704
commit
b7022655dc
|
@ -115,10 +115,18 @@ MHLO_PASS_SET = {
|
|||
"MatmulSingleDynamicBatchDim_basic",
|
||||
"Matmul_3d",
|
||||
"Matmul_4d",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"MeanDtypeModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanLargeInputModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MmTanhModule_basic",
|
||||
"Mv_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
"ReduceSumElementTypeBoolModule_basic",
|
||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||
"ReduceSumDimIntListKeepDimFloatModule_basic",
|
||||
|
@ -136,7 +144,6 @@ MHLO_PASS_SET = {
|
|||
"LiftFreshCopyModule_basic",
|
||||
"Mlp2LayerModuleNoBias_basic",
|
||||
"NumelModule_basic",
|
||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"ViewCollapseOnesMiddleModule_basic",
|
||||
|
@ -149,9 +156,6 @@ MHLO_PASS_SET = {
|
|||
"ViewTwoToThreeStaticModule_basic",
|
||||
"ViewExpandOnesMiddleOppModule_basic",
|
||||
"ViewOffsetBackwardTestStaticModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
|
|
|
@ -347,15 +347,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
}
|
||||
auto dtype = adaptor.dtype();
|
||||
if (!dtype.getType().isa<Torch::NoneType>()) {
|
||||
auto dstElemTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
|
@ -376,11 +376,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
for (int64_t i = 0; i < inputTy.getRank(); i++) {
|
||||
dims.push_back(i);
|
||||
}
|
||||
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
|
@ -401,7 +401,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -438,6 +439,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
|
@ -458,7 +460,9 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
mhloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -471,15 +475,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto outTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
}
|
||||
auto dtype = adaptor.dtype();
|
||||
if (!dtype.getType().isa<Torch::NoneType>()) {
|
||||
auto dstElemTy = getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
|
@ -522,6 +526,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
|
@ -566,7 +571,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
mhloReduceOp.getResult(0), outShapeTensor);
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOp(op, mhloReduceOp.getResults());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -1,239 +0,0 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.max.dim$keepdim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[T11:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T12:.*]] = mhlo.select %[[T11]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T13:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T14:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T15:.*]] = mhlo.select %[[T11]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T16:.*]] = mhlo.select %[[T13]], %[[T14]], %[[T15]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: mhlo.return %[[T12]], %[[T16]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS_1:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]] : tensor<2xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.dynamic_reshape %[[T6]]#0, %[[FROM_ELEMENTS_1]] : (tensor<?xf32>, tensor<2xi64>) -> tensor<?x1xf32>
|
||||
// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T6]]#1, %[[FROM_ELEMENTS_1]] : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x1xf32> -> !torch.vtensor<[?,1],f32>
|
||||
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
|
||||
// CHECK: return %[[T9]], %[[T10]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||
func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>) {
|
||||
%true = torch.constant.bool true
|
||||
%int1 = torch.constant.int 1
|
||||
%values, %indices = torch.aten.max.dim %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||
return %values, %indices : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,1],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.max.dim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[T9:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T10:.*]] = mhlo.select %[[T9]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T11:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T12:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T13:.*]] = mhlo.select %[[T9]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T14:.*]] = mhlo.select %[[T11]], %[[T12]], %[[T13]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: mhlo.return %[[T10]], %[[T14]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]]#0 : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T6]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
|
||||
// CHECK: return %[[T7]], %[[T8]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||
func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>) {
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%values, %indices = torch.aten.max.dim %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||
return %values, %indices : !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.argmax$keepdim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[T9:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T10:.*]] = mhlo.select %[[T9]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T11:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T12:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T13:.*]] = mhlo.select %[[T9]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T14:.*]] = mhlo.select %[[T11]], %[[T12]], %[[T13]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: mhlo.return %[[T10]], %[[T14]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS_1:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]] : tensor<2xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.dynamic_reshape %[[T6]]#1, %[[FROM_ELEMENTS_1]] : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x1xi64>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,1],si64>
|
||||
func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%true = torch.constant.bool true
|
||||
%indices = torch.aten.argmax %arg0, %int1, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,1],si64>
|
||||
return %indices : !torch.vtensor<[?,1],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.argmax(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[T3:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T6:.*]]:2 = mhlo.reduce(%[[T0]] init: %[[T3]]), (%[[T5]] init: %[[T4]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[ARG1:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>) (%[[ARG2:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[T8:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T9:.*]] = mhlo.select %[[T8]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T10:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T11:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T12:.*]] = mhlo.select %[[T8]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T13:.*]] = mhlo.select %[[T10]], %[[T11]], %[[T12]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: mhlo.return %[[T9]], %[[T13]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]]#1 : tensor<?xi64> -> !torch.vtensor<[?],si64>
|
||||
// CHECK: return %[[T7]] : !torch.vtensor<[?],si64>
|
||||
func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%false = torch.constant.bool false
|
||||
%indices = torch.aten.argmax %arg0, %int1, %false : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],si64>
|
||||
return %indices : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist$keepdim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reduce(%[[T0]] init: %[[T2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = arith.index_cast %[[DIM_0]] : index to i64
|
||||
// CHECK: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM_1]] : index to i64
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[C1_I64]], %[[T6]] : tensor<3xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.dynamic_reshape %[[T3]], %[[FROM_ELEMENTS]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<1x1x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<1x1x?xf32> -> !torch.vtensor<[1,1,?],f32>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[1,1,?],f32>
|
||||
func.func @torch.aten.sum.dim_Intlist$keepdim(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[1,1,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.sum.dim_IntList %arg0, %0, %true, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,?],f32>
|
||||
return %1 : !torch.vtensor<[1,1,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sum.dim_Intlist(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reduce(%[[T0]] init: %[[T2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
|
||||
func.func @torch.aten.sum.dim_Intlist(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?],f32>
|
||||
return %1 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sum(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T0]] init: %[[T1]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.max(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T0]] init: %[[T1]]) applies mhlo.maximum across dimensions = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue