[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)

- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
legalization
- Update end-to-end tests sets in xfail_sets.py

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3715/head
justin-ngo-arm 2024-09-16 12:40:24 -07:00 committed by GitHub
parent d6cf718f10
commit 14ef05a292
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 331 additions and 116 deletions

View File

@ -676,6 +676,53 @@ public:
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted for reduce_mean");
auto selfElemTy = selfTy.getElementType();
if (!selfElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
// TOSA ReduceAll and ReduceAny ops only accept bool input
if constexpr (std::is_same<AtenOpT, AtenAllDimOp>() ||
std::is_same<AtenOpT, AtenAnyDimOp>() ||
std::is_same<AtenOpT, AtenAllOp>() ||
std::is_same<AtenOpT, AtenAnyOp>()) {
self = tosa::promoteType(
rewriter, self,
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
}
// Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
if constexpr (std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
std::is_same<AtenOpT, AtenSumOp>() ||
std::is_same<AtenOpT, AtenProdDimIntOp>() ||
std::is_same<AtenOpT, AtenProdOp>()) {
auto dtype = op.getDtype();
int64_t dtypeInt;
if (!isa<Torch::NoneType>(dtype.getType())) {
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(op, "dtype is not a constant int");
FailureOr<Type> maybeDtypeType = getTypeForScalarType(
op.getContext(), (torch_upstream::ScalarType)dtypeInt);
if (failed(maybeDtypeType)) {
return rewriter.notifyMatchFailure(op, "dtype is undefined");
} else {
Type dtypeType = maybeDtypeType.value();
if (isa<mlir::IntegerType>(dtypeType))
dtypeType =
rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());
self = tosa::promoteType(
rewriter, self,
RankedTensorType::get(selfTy.getShape(), dtypeType));
}
} else {
if (selfElemTy.isInteger(1))
self = tosa::promoteType(rewriter, self, outputTy);
}
}
ElementsAttr reduceDimsAttr;
bool keepDims;
@ -3248,17 +3295,23 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
AtenMaxDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
const TypeConverter *typeConverter = this->getTypeConverter();
auto indicesType =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
if (!indicesType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
@ -3273,11 +3326,13 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
dim = toPositiveDim(dim, selfType.getRank());
if (!isValidDim(dim, selfType.getRank()))
return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");
return rewriter.notifyMatchFailure(op,
"dim must be less than tensor rank");
bool keepDim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");
return rewriter.notifyMatchFailure(op,
"keepdim must be a Scalar constant");
SmallVector<int64_t> reducedShape, prunedShape;
for (auto en :
@ -3293,36 +3348,51 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
Value reduceOp = rewriter.create<TosaOpT>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
adaptor.getSelf(), dimAttr);
self, dimAttr);
Value argMax = rewriter.create<tosa::ArgMaxOp>(
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
// of the input tensor, which will return indices of input's min values
Value argMaxOp;
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
Value negateOp =
rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
adaptor.getSelf(), dimAttr);
negateOp, dimAttr);
} else {
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr);
}
if (argMax.getType() != indicesType) {
argMax = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMax,
if (argMaxOp.getType() != indicesType) {
argMaxOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMaxOp,
rewriter.getDenseI64ArrayAttr(reducedShape));
}
if (!keepDim) {
reduceMax = rewriter.create<tosa::ReshapeOp>(
reduceOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
selfElemType),
reduceMax, prunedShapeAttr);
reduceOp, prunedShapeAttr);
}
rewriter.replaceOp(op, {reduceMax, argMax});
rewriter.replaceOp(op, {reduceOp, argMaxOp});
return success();
}
}
};
template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
@ -5623,6 +5693,10 @@ public:
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
mlir::tosa::convertReduceAllOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
@ -5635,8 +5709,21 @@ public:
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
target.addIllegalOp<AtenOp>(); \
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
@ -5727,7 +5814,6 @@ public:
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);

View File

@ -1625,6 +1625,7 @@ STABLEHLO_CRASHING_SET = {
TOSA_CRASHING_SET = {
# Runtime op verification: Out of bounds access
"IndexTensorNegativeIndexModule_basic",
"ReduceAllDimEmpty_basic",
}
FX_IMPORTER_TOSA_CRASHING_SET = {
@ -1643,6 +1644,36 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"ArgminModule_keepDim",
"ReduceAllDimBool_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceProdDtypeFloatModule_basic",
"ReduceProdDtypeIntModule_basic",
"ReduceProdElementTypeBoolModule_basic",
"ReduceProdFloatModule_basic",
"ReduceProdSignedIntModule_basic",
"ReduceProdUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"AtenTrilStaticModule_basic",
"AtenTrilWithNegDiagonalStaticModule_basic",
"AtenTrilWithPosDiagonalStaticModule_basic",
@ -2155,6 +2186,39 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"ArgminModule_keepDim",
"ReduceAllDimBool_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceProdDtypeFloatModule_basic",
"ReduceProdDtypeIntModule_basic",
"ReduceProdElementTypeBoolModule_basic",
"ReduceProdFloatModule_basic",
"ReduceProdSignedIntModule_basic",
"ReduceProdUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"AtenLinear1D_basic",
"AtenLinearMatVec_basic",
@ -3038,6 +3102,17 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"Rot90BasicModule_basic",
"Rot90DynamicDimsModule_basic",
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic",
@ -3075,16 +3150,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MultinomialModule2D_F32",
"MultinomialModule2D_basic",
"MultinomialModule_basic",
"ReduceAminSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAminmaxSingleDim_basic",
"ReduceAnyDimFloatModule_basic",
"RenormModuleFloat16_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex128IntModule_basic",
@ -3126,11 +3196,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"ArangeStartOutViewModule_basic",
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"ArgminModule_keepDim",
"ArgminModule_with_dim",
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
"AtenComplexViewModule_basic",
@ -3239,7 +3304,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ConvolutionModule2DTranspose_basic",
"CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CrossEntropyLossModule_basic",
"CumsumInputDtypeInt32Module_basic",
"CumsumModule_basic",
"CumsumStaticModule_basic",
@ -3483,9 +3547,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinalgVectorNormComplexModule_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulStaticBroadcast_basic",
@ -3524,10 +3586,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
"MaxPool3dWithIndicesStaticModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic",
"MseLossMeanReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"MulFloatModule_basic",
@ -3566,9 +3626,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NllLossModuleBackwardWeight_basic",
"NllLossModuleBackward_basic",
"NllLossModuleBackward_ignore_index",
"NllLossModule_1D_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
@ -3613,14 +3670,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"RandnLikeDtypeModule_basic",
"RandnLikeModule_basic",
"RandnModule_basic",
"ReduceAllDimBool_basic",
"ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceFrobeniusNormComplexModule_basic",
"ReduceL1NormComplexModule_basic",
"ReduceL1NormWithDTypeModule_basic",
@ -3628,34 +3678,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimComplexModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinAlongDimNegative_basic",
"ReduceMinAlongDimSignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ReduceMinAlongDim_basic",
"ReduceMinFloatModule_basic",
"ReduceMinKeepDimReturnBoth_basic",
"ReduceMinKeepDim_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceProdDimIntFloatModule_basic",
"ReduceProdDtypeFloatModule_basic",
"ReduceProdDtypeIntModule_basic",
"ReduceProdElementTypeBoolModule_basic",
"ReduceProdFloatModule_basic",
"ReduceProdSignedIntModule_basic",
"ReduceProdUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
@ -3672,7 +3697,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReplicationPad2dModule_top0",
"RollModule_basic",
"RsubInt0d_NumToTensor_Module_basic",
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
@ -3801,6 +3825,17 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"Rot90BasicModule_basic",
"Rot90DynamicDimsModule_basic",
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
"SafeSoftmaxModule_basic",
"SafeSoftmaxNonNoneDtypeModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
@ -3916,7 +3951,6 @@ ONNX_TOSA_XFAIL_SET = {
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
"ArgminModule_keepDim",
"ArgminModule_with_dim",
"AtenComplex64Module_basic",
"AtenComplexImagModule_basic",
@ -4162,7 +4196,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseExpm1Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseGeFloatIntScalarModule_basic",
@ -4624,7 +4657,6 @@ ONNX_TOSA_XFAIL_SET = {
"ScalarImplicitIntModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterReduceFloatMaxModule",

View File

@ -1373,3 +1373,100 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
return %0 : !torch.vtensor<[2,4],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
// CHECK: }
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
%true = torch.constant.bool true
%int2 = torch.constant.int 2
%values, %indices = torch.aten.min.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
return %1 : tensor<3x2x1xf32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.min$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
// CHECK: }
func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.aten.min %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.max$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
// CHECK: }
func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.constant.none
// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32>
// CHECK: }
func.func @torch.aten.prod.dim_int$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
%dim = torch.constant.int 2
%keepdims = torch.constant.bool true
%dtype = torch.constant.none
%0 = torch.aten.prod.dim_int %arg0, %dim, %keepdims, %dtype: !torch.vtensor<[3,2,3],f32> , !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.all.dim$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xi1>) -> tensor<3x2x1xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x2x1xi1> -> !torch.vtensor<[3,2,1],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,2,1],i1>
// CHECK: }
func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
%dim = torch.constant.int 2
%keepdims = torch.constant.bool true
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
return %0 : !torch.vtensor<[3,2,1],i1>
}