mirror of https://github.com/llvm/torch-mlir
[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)
- Add Torch to TOSA legalization for the following reduction ops: + aten.min.dim + aten.min + aten.max + aten.prod + aten.prod.dim_int + aten.all.dim - Add dtype casting support for reduce sum and prod ops - Extend aten.max.dim legalization to a template to support aten.min.dim legalization - Update end-to-end tests sets in xfail_sets.py Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3715/head
parent
d6cf718f10
commit
14ef05a292
|
@ -676,6 +676,53 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor type outputs permitted for reduce_mean");
|
||||
|
||||
auto selfElemTy = selfTy.getElementType();
|
||||
if (!selfElemTy.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
|
||||
// TOSA ReduceAll and ReduceAny ops only accept bool input
|
||||
if constexpr (std::is_same<AtenOpT, AtenAllDimOp>() ||
|
||||
std::is_same<AtenOpT, AtenAnyDimOp>() ||
|
||||
std::is_same<AtenOpT, AtenAllOp>() ||
|
||||
std::is_same<AtenOpT, AtenAnyOp>()) {
|
||||
self = tosa::promoteType(
|
||||
rewriter, self,
|
||||
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
|
||||
}
|
||||
|
||||
// Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
|
||||
if constexpr (std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
|
||||
std::is_same<AtenOpT, AtenSumOp>() ||
|
||||
std::is_same<AtenOpT, AtenProdDimIntOp>() ||
|
||||
std::is_same<AtenOpT, AtenProdOp>()) {
|
||||
auto dtype = op.getDtype();
|
||||
int64_t dtypeInt;
|
||||
if (!isa<Torch::NoneType>(dtype.getType())) {
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(op, "dtype is not a constant int");
|
||||
|
||||
FailureOr<Type> maybeDtypeType = getTypeForScalarType(
|
||||
op.getContext(), (torch_upstream::ScalarType)dtypeInt);
|
||||
if (failed(maybeDtypeType)) {
|
||||
return rewriter.notifyMatchFailure(op, "dtype is undefined");
|
||||
} else {
|
||||
Type dtypeType = maybeDtypeType.value();
|
||||
|
||||
if (isa<mlir::IntegerType>(dtypeType))
|
||||
dtypeType =
|
||||
rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());
|
||||
|
||||
self = tosa::promoteType(
|
||||
rewriter, self,
|
||||
RankedTensorType::get(selfTy.getShape(), dtypeType));
|
||||
}
|
||||
} else {
|
||||
if (selfElemTy.isInteger(1))
|
||||
self = tosa::promoteType(rewriter, self, outputTy);
|
||||
}
|
||||
}
|
||||
|
||||
ElementsAttr reduceDimsAttr;
|
||||
bool keepDims;
|
||||
|
||||
|
@ -3248,17 +3295,23 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
||||
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||
auto indicesType =
|
||||
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||
dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
|
||||
if (!indicesType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
|
@ -3273,11 +3326,13 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
dim = toPositiveDim(dim, selfType.getRank());
|
||||
|
||||
if (!isValidDim(dim, selfType.getRank()))
|
||||
return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"dim must be less than tensor rank");
|
||||
|
||||
bool keepDim;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||
return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"keepdim must be a Scalar constant");
|
||||
|
||||
SmallVector<int64_t> reducedShape, prunedShape;
|
||||
for (auto en :
|
||||
|
@ -3293,36 +3348,51 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
|
||||
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
|
||||
|
||||
Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
|
||||
Value reduceOp = rewriter.create<TosaOpT>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
|
||||
selfElemType),
|
||||
adaptor.getSelf(), dimAttr);
|
||||
self, dimAttr);
|
||||
|
||||
Value argMax = rewriter.create<tosa::ArgMaxOp>(
|
||||
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
|
||||
// of the input tensor, which will return indices of input's min values
|
||||
Value argMaxOp;
|
||||
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
|
||||
Value negateOp =
|
||||
rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);
|
||||
|
||||
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||
indicesElemType),
|
||||
adaptor.getSelf(), dimAttr);
|
||||
negateOp, dimAttr);
|
||||
} else {
|
||||
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||
indicesElemType),
|
||||
self, dimAttr);
|
||||
}
|
||||
|
||||
if (argMax.getType() != indicesType) {
|
||||
argMax = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), indicesType, argMax,
|
||||
if (argMaxOp.getType() != indicesType) {
|
||||
argMaxOp = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), indicesType, argMaxOp,
|
||||
rewriter.getDenseI64ArrayAttr(reducedShape));
|
||||
}
|
||||
|
||||
if (!keepDim) {
|
||||
reduceMax = rewriter.create<tosa::ReshapeOp>(
|
||||
reduceOp = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||
selfElemType),
|
||||
reduceMax, prunedShapeAttr);
|
||||
reduceOp, prunedShapeAttr);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {reduceMax, argMax});
|
||||
rewriter.replaceOp(op, {reduceOp, argMaxOp});
|
||||
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||
|
@ -5623,6 +5693,10 @@ public:
|
|||
typeConverter, context);
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
||||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
|
||||
mlir::tosa::convertReduceAllOp)
|
||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
|
@ -5635,8 +5709,21 @@ public:
|
|||
mlir::tosa::convertReduceAnyOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
|
||||
mlir::tosa::convertReduceSumOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
|
||||
mlir::tosa::convertReduceMaxOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
|
||||
mlir::tosa::convertReduceMinOp)
|
||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
|
||||
mlir::tosa::convertReduceProdOp)
|
||||
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
|
||||
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
|
||||
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
|
||||
|
||||
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
||||
|
@ -5727,7 +5814,6 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
|
|
|
@ -1625,6 +1625,7 @@ STABLEHLO_CRASHING_SET = {
|
|||
TOSA_CRASHING_SET = {
|
||||
# Runtime op verification: Out of bounds access
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"ReduceAllDimEmpty_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||
|
@ -1643,6 +1644,36 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
"ArgminModule_keepDim",
|
||||
"ReduceAllDimBool_basic",
|
||||
"ReduceAllDimFloat_basic",
|
||||
"ReduceAllDimInt_basic",
|
||||
"ReduceAllFloatModule_basic",
|
||||
"ReduceAllIntModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceAnyIntModule_basic",
|
||||
"ReduceMaxAllDims_basic",
|
||||
"ReduceMaxFloatModule_basic",
|
||||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
"ReduceMinFloatModule_basic",
|
||||
"ReduceMinSignedIntModule_basic",
|
||||
"ReduceMinUnsignedIntModule_basic",
|
||||
"ReduceProdDtypeFloatModule_basic",
|
||||
"ReduceProdDtypeIntModule_basic",
|
||||
"ReduceProdElementTypeBoolModule_basic",
|
||||
"ReduceProdFloatModule_basic",
|
||||
"ReduceProdSignedIntModule_basic",
|
||||
"ReduceProdUnsignedIntModule_basic",
|
||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
"ReduceSumDtypeFloatModule_basic",
|
||||
"ReduceSumDtypeIntModule_basic",
|
||||
"ReduceSumElementTypeBoolModule_basic",
|
||||
"AtenTrilStaticModule_basic",
|
||||
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||
|
@ -2155,6 +2186,39 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
TOSA_PASS_SET
|
||||
| {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
"ArgminModule_keepDim",
|
||||
"ReduceAllDimBool_basic",
|
||||
"ReduceAllDimFloat_basic",
|
||||
"ReduceAllDimInt_basic",
|
||||
"ReduceAllFloatModule_basic",
|
||||
"ReduceAllIntModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceAnyIntModule_basic",
|
||||
"ReduceMaxAllDims_basic",
|
||||
"ReduceMaxFloatModule_basic",
|
||||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
"ReduceMinFloatModule_basic",
|
||||
"ReduceMinSignedIntModule_basic",
|
||||
"ReduceMinUnsignedIntModule_basic",
|
||||
"ReduceProdDtypeFloatModule_basic",
|
||||
"ReduceProdDtypeIntModule_basic",
|
||||
"ReduceProdElementTypeBoolModule_basic",
|
||||
"ReduceProdFloatModule_basic",
|
||||
"ReduceProdSignedIntModule_basic",
|
||||
"ReduceProdUnsignedIntModule_basic",
|
||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
"ReduceSumDtypeFloatModule_basic",
|
||||
"ReduceSumDtypeIntModule_basic",
|
||||
"ReduceSumElementTypeBoolModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionMaskModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic",
|
||||
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||
"AtenLinear1D_basic",
|
||||
"AtenLinearMatVec_basic",
|
||||
|
@ -3038,6 +3102,17 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"AtenPolarDoubleModule_basic",
|
||||
"AtenPolarFloatModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
"HstackBasicIntFloatModule_basic",
|
||||
"HstackBasicIntModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
|
@ -3075,16 +3150,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"MultinomialModule2D_F32",
|
||||
"MultinomialModule2D_basic",
|
||||
"MultinomialModule_basic",
|
||||
"ReduceAminSingleDim_basic",
|
||||
"ReduceAminmaxAllDims_basic",
|
||||
"ReduceAminmaxSingleDim_basic",
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"RenormModuleFloat16_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
|
@ -3126,11 +3196,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
"ArgminModule_keepDim",
|
||||
"ArgminModule_with_dim",
|
||||
"AtenComplexImagModule_basic",
|
||||
"AtenComplexRealModule_basic",
|
||||
"AtenComplexViewModule_basic",
|
||||
|
@ -3239,7 +3304,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ConvolutionModule2DTranspose_basic",
|
||||
"CopyWithDifferentDTypesModule_basic",
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
"CrossEntropyLossModule_basic",
|
||||
"CumsumInputDtypeInt32Module_basic",
|
||||
"CumsumModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
|
@ -3483,9 +3547,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"LinalgVectorNormComplexModule_basic",
|
||||
"LinspaceDtypeModule_basic",
|
||||
"LinspaceEmptyModule_basic",
|
||||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"MatmulBroadcastBatchDim_basic",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
@ -3524,10 +3586,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
|
||||
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
||||
"MaxPool3dWithIndicesStaticModule_basic",
|
||||
"MeanDimDtypeModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"MeanDimNoneDimModule_basic",
|
||||
"MeanDtypeModule_basic",
|
||||
"MseLossMeanReductionModule_basic",
|
||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
|
@ -3566,9 +3626,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"NllLossModuleBackwardWeight_basic",
|
||||
"NllLossModuleBackward_basic",
|
||||
"NllLossModuleBackward_ignore_index",
|
||||
"NllLossModule_1D_basic",
|
||||
"NllLossModule_mean_basic",
|
||||
"NllLossModule_sum_basic",
|
||||
"NormScalarComplexModule_basic",
|
||||
"NormScalarModule_basic",
|
||||
"NormScalarOptDimKeepDimComplexModule_basic",
|
||||
|
@ -3613,14 +3670,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"RandnLikeDtypeModule_basic",
|
||||
"RandnLikeModule_basic",
|
||||
"RandnModule_basic",
|
||||
"ReduceAllDimBool_basic",
|
||||
"ReduceAllDimEmpty_basic",
|
||||
"ReduceAllDimFloat_basic",
|
||||
"ReduceAllDimInt_basic",
|
||||
"ReduceAllFloatModule_basic",
|
||||
"ReduceAllIntModule_basic",
|
||||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceAnyIntModule_basic",
|
||||
"ReduceFrobeniusNormComplexModule_basic",
|
||||
"ReduceL1NormComplexModule_basic",
|
||||
"ReduceL1NormWithDTypeModule_basic",
|
||||
|
@ -3628,34 +3678,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ReduceL3NormAllDimsModule_basic",
|
||||
"ReduceL3NormKeepDimComplexModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"ReduceMaxAllDims_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMaxFloatModule_basic",
|
||||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
"ReduceMinAlongDimNegative_basic",
|
||||
"ReduceMinAlongDimSignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDim_basic",
|
||||
"ReduceMinFloatModule_basic",
|
||||
"ReduceMinKeepDimReturnBoth_basic",
|
||||
"ReduceMinKeepDim_basic",
|
||||
"ReduceMinSignedIntModule_basic",
|
||||
"ReduceMinUnsignedIntModule_basic",
|
||||
"ReduceProdDimIntFloatModule_basic",
|
||||
"ReduceProdDtypeFloatModule_basic",
|
||||
"ReduceProdDtypeIntModule_basic",
|
||||
"ReduceProdElementTypeBoolModule_basic",
|
||||
"ReduceProdFloatModule_basic",
|
||||
"ReduceProdSignedIntModule_basic",
|
||||
"ReduceProdUnsignedIntModule_basic",
|
||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
"ReduceSumDtypeFloatModule_basic",
|
||||
"ReduceSumDtypeIntModule_basic",
|
||||
"ReduceSumElementTypeBoolModule_basic",
|
||||
"ReflectionPad1dModule2dInput_Right",
|
||||
"ReflectionPad1dModule2dInput_basic",
|
||||
"ReflectionPad1dModule3dInput_Left",
|
||||
|
@ -3672,7 +3697,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ReplicationPad2dModule_top0",
|
||||
"RollModule_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"RsubIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
|
@ -3801,6 +3825,17 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
"HstackBasicIntFloatModule_basic",
|
||||
"HstackBasicIntModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"SafeSoftmaxModule_basic",
|
||||
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
|
@ -3916,7 +3951,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
"ArgminModule_keepDim",
|
||||
"ArgminModule_with_dim",
|
||||
"AtenComplex64Module_basic",
|
||||
"AtenComplexImagModule_basic",
|
||||
|
@ -4162,7 +4196,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseGeFloatIntScalarModule_basic",
|
||||
|
@ -4624,7 +4657,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ScalarImplicitIntModule_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
|
|
|
@ -1373,3 +1373,100 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
|
|||
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
|
||||
return %0 : !torch.vtensor<[2,4],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
%true = torch.constant.bool true
|
||||
%int2 = torch.constant.int 2
|
||||
%values, %indices = torch.aten.min.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
|
||||
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
return %1 : tensor<3x2x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.min$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||
%0 = torch.aten.min %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.max$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||
%0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.prod.dim_int$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
|
||||
%dim = torch.constant.int 2
|
||||
%keepdims = torch.constant.bool true
|
||||
%dtype = torch.constant.none
|
||||
%0 = torch.aten.prod.dim_int %arg0, %dim, %keepdims, %dtype: !torch.vtensor<[3,2,3],f32> , !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
|
||||
return %0 : !torch.vtensor<[3,2,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.all.dim$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xi1>) -> tensor<3x2x1xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x2x1xi1> -> !torch.vtensor<[3,2,1],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,2,1],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
|
||||
%dim = torch.constant.int 2
|
||||
%keepdims = torch.constant.bool true
|
||||
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
|
||||
return %0 : !torch.vtensor<[3,2,1],i1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue