mirror of https://github.com/llvm/torch-mlir
[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)
- Add Torch to TOSA legalization for the following reduction ops: + aten.min.dim + aten.min + aten.max + aten.prod + aten.prod.dim_int + aten.all.dim - Add dtype casting support for reduce sum and prod ops - Extend aten.max.dim legalization to a template to support aten.min.dim legalization - Update end-to-end tests sets in xfail_sets.py Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3715/head
parent
d6cf718f10
commit
14ef05a292
|
@ -676,6 +676,53 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor type outputs permitted for reduce_mean");
|
op, "Only ranked tensor type outputs permitted for reduce_mean");
|
||||||
|
|
||||||
|
auto selfElemTy = selfTy.getElementType();
|
||||||
|
if (!selfElemTy.isIntOrFloat())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
|
||||||
|
// TOSA ReduceAll and ReduceAny ops only accept bool input
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenAllDimOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAnyDimOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAllOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenAnyOp>()) {
|
||||||
|
self = tosa::promoteType(
|
||||||
|
rewriter, self,
|
||||||
|
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenSumOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenProdDimIntOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenProdOp>()) {
|
||||||
|
auto dtype = op.getDtype();
|
||||||
|
int64_t dtypeInt;
|
||||||
|
if (!isa<Torch::NoneType>(dtype.getType())) {
|
||||||
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dtype is not a constant int");
|
||||||
|
|
||||||
|
FailureOr<Type> maybeDtypeType = getTypeForScalarType(
|
||||||
|
op.getContext(), (torch_upstream::ScalarType)dtypeInt);
|
||||||
|
if (failed(maybeDtypeType)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dtype is undefined");
|
||||||
|
} else {
|
||||||
|
Type dtypeType = maybeDtypeType.value();
|
||||||
|
|
||||||
|
if (isa<mlir::IntegerType>(dtypeType))
|
||||||
|
dtypeType =
|
||||||
|
rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());
|
||||||
|
|
||||||
|
self = tosa::promoteType(
|
||||||
|
rewriter, self,
|
||||||
|
RankedTensorType::get(selfTy.getShape(), dtypeType));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (selfElemTy.isInteger(1))
|
||||||
|
self = tosa::promoteType(rewriter, self, outputTy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ElementsAttr reduceDimsAttr;
|
ElementsAttr reduceDimsAttr;
|
||||||
bool keepDims;
|
bool keepDims;
|
||||||
|
|
||||||
|
@ -3248,81 +3295,104 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
|
||||||
AtenMaxDimOp op, OpAdaptor adaptor,
|
public:
|
||||||
ConversionPatternRewriter &rewriter) const {
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
auto self = adaptor.getSelf();
|
||||||
if (!selfType)
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
auto indicesType =
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
|
auto indicesType =
|
||||||
if (!indicesType)
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
|
||||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
if (!indicesType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
auto selfElemType = selfType.getElementType();
|
auto selfElemType = selfType.getElementType();
|
||||||
auto indicesElemType = indicesType.getElementType();
|
auto indicesElemType = indicesType.getElementType();
|
||||||
|
|
||||||
// Only statically deducible values are currently supported
|
// Only statically deducible values are currently supported
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
|
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
|
||||||
|
|
||||||
dim = toPositiveDim(dim, selfType.getRank());
|
dim = toPositiveDim(dim, selfType.getRank());
|
||||||
|
|
||||||
if (!isValidDim(dim, selfType.getRank()))
|
if (!isValidDim(dim, selfType.getRank()))
|
||||||
return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"dim must be less than tensor rank");
|
||||||
|
|
||||||
bool keepDim;
|
bool keepDim;
|
||||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
||||||
return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"keepdim must be a Scalar constant");
|
||||||
|
|
||||||
SmallVector<int64_t> reducedShape, prunedShape;
|
SmallVector<int64_t> reducedShape, prunedShape;
|
||||||
for (auto en :
|
for (auto en :
|
||||||
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
|
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
|
||||||
if (static_cast<int64_t>(en.index()) == dim) {
|
if (static_cast<int64_t>(en.index()) == dim) {
|
||||||
reducedShape.push_back(1);
|
reducedShape.push_back(1);
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
reducedShape.push_back(en.value());
|
||||||
|
prunedShape.push_back(en.value());
|
||||||
}
|
}
|
||||||
reducedShape.push_back(en.value());
|
|
||||||
prunedShape.push_back(en.value());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
|
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
|
||||||
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
|
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
|
||||||
|
|
||||||
Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
|
Value reduceOp = rewriter.create<TosaOpT>(
|
||||||
op->getLoc(),
|
|
||||||
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
|
|
||||||
selfElemType),
|
|
||||||
adaptor.getSelf(), dimAttr);
|
|
||||||
|
|
||||||
Value argMax = rewriter.create<tosa::ArgMaxOp>(
|
|
||||||
op->getLoc(),
|
|
||||||
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
|
||||||
indicesElemType),
|
|
||||||
adaptor.getSelf(), dimAttr);
|
|
||||||
|
|
||||||
if (argMax.getType() != indicesType) {
|
|
||||||
argMax = rewriter.create<tosa::ReshapeOp>(
|
|
||||||
op->getLoc(), indicesType, argMax,
|
|
||||||
rewriter.getDenseI64ArrayAttr(reducedShape));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!keepDim) {
|
|
||||||
reduceMax = rewriter.create<tosa::ReshapeOp>(
|
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
|
||||||
selfElemType),
|
selfElemType),
|
||||||
reduceMax, prunedShapeAttr);
|
self, dimAttr);
|
||||||
|
|
||||||
|
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
|
||||||
|
// of the input tensor, which will return indices of input's min values
|
||||||
|
Value argMaxOp;
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
|
||||||
|
Value negateOp =
|
||||||
|
rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);
|
||||||
|
|
||||||
|
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||||
|
indicesElemType),
|
||||||
|
negateOp, dimAttr);
|
||||||
|
} else {
|
||||||
|
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||||
|
indicesElemType),
|
||||||
|
self, dimAttr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argMaxOp.getType() != indicesType) {
|
||||||
|
argMaxOp = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), indicesType, argMaxOp,
|
||||||
|
rewriter.getDenseI64ArrayAttr(reducedShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!keepDim) {
|
||||||
|
reduceOp = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
|
||||||
|
selfElemType),
|
||||||
|
reduceOp, prunedShapeAttr);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {reduceOp, argMaxOp});
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
};
|
||||||
rewriter.replaceOp(op, {reduceMax, argMax});
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||||
|
@ -5623,6 +5693,10 @@ public:
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
|
||||||
mlir::tosa::convertReduceAnyOp)
|
mlir::tosa::convertReduceAnyOp)
|
||||||
|
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
|
||||||
|
mlir::tosa::convertReduceAllOp)
|
||||||
|
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
|
||||||
|
mlir::tosa::convertReduceProdOp)
|
||||||
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN
|
||||||
|
|
||||||
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||||
|
@ -5635,8 +5709,21 @@ public:
|
||||||
mlir::tosa::convertReduceAnyOp)
|
mlir::tosa::convertReduceAnyOp)
|
||||||
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
|
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
|
||||||
mlir::tosa::convertReduceSumOp)
|
mlir::tosa::convertReduceSumOp)
|
||||||
|
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
|
||||||
|
mlir::tosa::convertReduceMaxOp)
|
||||||
|
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
|
||||||
|
mlir::tosa::convertReduceMinOp)
|
||||||
|
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
|
||||||
|
mlir::tosa::convertReduceProdOp)
|
||||||
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||||
|
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
|
||||||
|
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
|
||||||
|
#undef INSERT_INDICES_REDUCTION_OP_PATTERN
|
||||||
|
|
||||||
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
|
||||||
|
@ -5727,7 +5814,6 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
|
|
||||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||||
|
|
|
@ -1625,6 +1625,7 @@ STABLEHLO_CRASHING_SET = {
|
||||||
TOSA_CRASHING_SET = {
|
TOSA_CRASHING_SET = {
|
||||||
# Runtime op verification: Out of bounds access
|
# Runtime op verification: Out of bounds access
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
|
"ReduceAllDimEmpty_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
|
@ -1643,6 +1644,36 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"ArgminIntModule_basic",
|
||||||
|
"ArgminIntModule_multiple_mins",
|
||||||
|
"ArgminModule_basic",
|
||||||
|
"ArgminModule_keepDim",
|
||||||
|
"ReduceAllDimBool_basic",
|
||||||
|
"ReduceAllDimFloat_basic",
|
||||||
|
"ReduceAllDimInt_basic",
|
||||||
|
"ReduceAllFloatModule_basic",
|
||||||
|
"ReduceAllIntModule_basic",
|
||||||
|
"ReduceAnyFloatModule_basic",
|
||||||
|
"ReduceAnyIntModule_basic",
|
||||||
|
"ReduceMaxAllDims_basic",
|
||||||
|
"ReduceMaxFloatModule_basic",
|
||||||
|
"ReduceMaxSignedIntModule_basic",
|
||||||
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
|
"ReduceMinFloatModule_basic",
|
||||||
|
"ReduceMinSignedIntModule_basic",
|
||||||
|
"ReduceMinUnsignedIntModule_basic",
|
||||||
|
"ReduceProdDtypeFloatModule_basic",
|
||||||
|
"ReduceProdDtypeIntModule_basic",
|
||||||
|
"ReduceProdElementTypeBoolModule_basic",
|
||||||
|
"ReduceProdFloatModule_basic",
|
||||||
|
"ReduceProdSignedIntModule_basic",
|
||||||
|
"ReduceProdUnsignedIntModule_basic",
|
||||||
|
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||||
|
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||||
|
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||||
|
"ReduceSumDtypeFloatModule_basic",
|
||||||
|
"ReduceSumDtypeIntModule_basic",
|
||||||
|
"ReduceSumElementTypeBoolModule_basic",
|
||||||
"AtenTrilStaticModule_basic",
|
"AtenTrilStaticModule_basic",
|
||||||
"AtenTrilWithNegDiagonalStaticModule_basic",
|
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||||
"AtenTrilWithPosDiagonalStaticModule_basic",
|
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||||
|
@ -2155,6 +2186,39 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"ArgminIntModule_basic",
|
||||||
|
"ArgminIntModule_multiple_mins",
|
||||||
|
"ArgminModule_basic",
|
||||||
|
"ArgminModule_keepDim",
|
||||||
|
"ReduceAllDimBool_basic",
|
||||||
|
"ReduceAllDimFloat_basic",
|
||||||
|
"ReduceAllDimInt_basic",
|
||||||
|
"ReduceAllFloatModule_basic",
|
||||||
|
"ReduceAllIntModule_basic",
|
||||||
|
"ReduceAnyFloatModule_basic",
|
||||||
|
"ReduceAnyIntModule_basic",
|
||||||
|
"ReduceMaxAllDims_basic",
|
||||||
|
"ReduceMaxFloatModule_basic",
|
||||||
|
"ReduceMaxSignedIntModule_basic",
|
||||||
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
|
"ReduceMinFloatModule_basic",
|
||||||
|
"ReduceMinSignedIntModule_basic",
|
||||||
|
"ReduceMinUnsignedIntModule_basic",
|
||||||
|
"ReduceProdDtypeFloatModule_basic",
|
||||||
|
"ReduceProdDtypeIntModule_basic",
|
||||||
|
"ReduceProdElementTypeBoolModule_basic",
|
||||||
|
"ReduceProdFloatModule_basic",
|
||||||
|
"ReduceProdSignedIntModule_basic",
|
||||||
|
"ReduceProdUnsignedIntModule_basic",
|
||||||
|
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||||
|
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||||
|
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||||
|
"ReduceSumDtypeFloatModule_basic",
|
||||||
|
"ReduceSumDtypeIntModule_basic",
|
||||||
|
"ReduceSumElementTypeBoolModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||||
"AtenLinear1D_basic",
|
"AtenLinear1D_basic",
|
||||||
"AtenLinearMatVec_basic",
|
"AtenLinearMatVec_basic",
|
||||||
|
@ -3038,6 +3102,17 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"AtenPolarDoubleModule_basic",
|
||||||
|
"AtenPolarFloatModule_basic",
|
||||||
|
"HstackBasicComplexModule_basic",
|
||||||
|
"HstackBasicFloatModule_basic",
|
||||||
|
"HstackBasicIntFloatModule_basic",
|
||||||
|
"HstackBasicIntModule_basic",
|
||||||
|
"Rot90BasicModule_basic",
|
||||||
|
"Rot90DynamicDimsModule_basic",
|
||||||
|
"Rot90MultipleRotationsModule_basic",
|
||||||
|
"Rot90NegativeEvenRotationsModule_basic",
|
||||||
|
"Rot90NegativeOddRotationsModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
|
@ -3075,16 +3150,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"MultinomialModule2D_F32",
|
"MultinomialModule2D_F32",
|
||||||
"MultinomialModule2D_basic",
|
"MultinomialModule2D_basic",
|
||||||
"MultinomialModule_basic",
|
"MultinomialModule_basic",
|
||||||
"ReduceAminSingleDim_basic",
|
|
||||||
"ReduceAminmaxAllDims_basic",
|
|
||||||
"ReduceAminmaxSingleDim_basic",
|
|
||||||
"ReduceAnyDimFloatModule_basic",
|
|
||||||
"RenormModuleFloat16_basic",
|
"RenormModuleFloat16_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
|
||||||
"ScatterAddStaticModule_basic",
|
"ScatterAddStaticModule_basic",
|
||||||
"TensorsConcatComplex128FloatModule_basic",
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
"TensorsConcatComplex128IntModule_basic",
|
"TensorsConcatComplex128IntModule_basic",
|
||||||
|
@ -3126,11 +3196,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AnyBoolFalseModule_basic",
|
"AnyBoolFalseModule_basic",
|
||||||
"AnyBoolTrueModule_basic",
|
"AnyBoolTrueModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"ArgminIntModule_basic",
|
|
||||||
"ArgminIntModule_multiple_mins",
|
|
||||||
"ArgminModule_basic",
|
|
||||||
"ArgminModule_keepDim",
|
|
||||||
"ArgminModule_with_dim",
|
|
||||||
"AtenComplexImagModule_basic",
|
"AtenComplexImagModule_basic",
|
||||||
"AtenComplexRealModule_basic",
|
"AtenComplexRealModule_basic",
|
||||||
"AtenComplexViewModule_basic",
|
"AtenComplexViewModule_basic",
|
||||||
|
@ -3239,7 +3304,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ConvolutionModule2DTranspose_basic",
|
"ConvolutionModule2DTranspose_basic",
|
||||||
"CopyWithDifferentDTypesModule_basic",
|
"CopyWithDifferentDTypesModule_basic",
|
||||||
"CosineSimilarityStaticBroadcastModule_basic",
|
"CosineSimilarityStaticBroadcastModule_basic",
|
||||||
"CrossEntropyLossModule_basic",
|
|
||||||
"CumsumInputDtypeInt32Module_basic",
|
"CumsumInputDtypeInt32Module_basic",
|
||||||
"CumsumModule_basic",
|
"CumsumModule_basic",
|
||||||
"CumsumStaticModule_basic",
|
"CumsumStaticModule_basic",
|
||||||
|
@ -3483,9 +3547,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"LinalgVectorNormComplexModule_basic",
|
"LinalgVectorNormComplexModule_basic",
|
||||||
"LinspaceDtypeModule_basic",
|
"LinspaceDtypeModule_basic",
|
||||||
"LinspaceEmptyModule_basic",
|
"LinspaceEmptyModule_basic",
|
||||||
"LinspaceModule_basic",
|
|
||||||
"LinspaceOneSizeModule_basic",
|
"LinspaceOneSizeModule_basic",
|
||||||
"LinspaceTwoSizeModule_basic",
|
|
||||||
"MaskedFillTensorFloatValueModule_basic",
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
"MatmulBroadcastBatchDim_basic",
|
"MatmulBroadcastBatchDim_basic",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
@ -3524,10 +3586,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
|
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
|
||||||
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
||||||
"MaxPool3dWithIndicesStaticModule_basic",
|
"MaxPool3dWithIndicesStaticModule_basic",
|
||||||
"MeanDimDtypeModule_basic",
|
|
||||||
"MeanDimEmptyDimModule_basic",
|
"MeanDimEmptyDimModule_basic",
|
||||||
"MeanDimNoneDimModule_basic",
|
"MeanDimNoneDimModule_basic",
|
||||||
"MeanDtypeModule_basic",
|
|
||||||
"MseLossMeanReductionModule_basic",
|
"MseLossMeanReductionModule_basic",
|
||||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||||
"MulFloatModule_basic",
|
"MulFloatModule_basic",
|
||||||
|
@ -3566,9 +3626,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"NllLossModuleBackwardWeight_basic",
|
"NllLossModuleBackwardWeight_basic",
|
||||||
"NllLossModuleBackward_basic",
|
"NllLossModuleBackward_basic",
|
||||||
"NllLossModuleBackward_ignore_index",
|
"NllLossModuleBackward_ignore_index",
|
||||||
"NllLossModule_1D_basic",
|
|
||||||
"NllLossModule_mean_basic",
|
|
||||||
"NllLossModule_sum_basic",
|
|
||||||
"NormScalarComplexModule_basic",
|
"NormScalarComplexModule_basic",
|
||||||
"NormScalarModule_basic",
|
"NormScalarModule_basic",
|
||||||
"NormScalarOptDimKeepDimComplexModule_basic",
|
"NormScalarOptDimKeepDimComplexModule_basic",
|
||||||
|
@ -3613,14 +3670,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"RandnLikeDtypeModule_basic",
|
"RandnLikeDtypeModule_basic",
|
||||||
"RandnLikeModule_basic",
|
"RandnLikeModule_basic",
|
||||||
"RandnModule_basic",
|
"RandnModule_basic",
|
||||||
"ReduceAllDimBool_basic",
|
|
||||||
"ReduceAllDimEmpty_basic",
|
"ReduceAllDimEmpty_basic",
|
||||||
"ReduceAllDimFloat_basic",
|
|
||||||
"ReduceAllDimInt_basic",
|
|
||||||
"ReduceAllFloatModule_basic",
|
|
||||||
"ReduceAllIntModule_basic",
|
|
||||||
"ReduceAnyFloatModule_basic",
|
|
||||||
"ReduceAnyIntModule_basic",
|
|
||||||
"ReduceFrobeniusNormComplexModule_basic",
|
"ReduceFrobeniusNormComplexModule_basic",
|
||||||
"ReduceL1NormComplexModule_basic",
|
"ReduceL1NormComplexModule_basic",
|
||||||
"ReduceL1NormWithDTypeModule_basic",
|
"ReduceL1NormWithDTypeModule_basic",
|
||||||
|
@ -3628,34 +3678,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReduceL3NormAllDimsModule_basic",
|
"ReduceL3NormAllDimsModule_basic",
|
||||||
"ReduceL3NormKeepDimComplexModule_basic",
|
"ReduceL3NormKeepDimComplexModule_basic",
|
||||||
"ReduceL3NormKeepDimModule_basic",
|
"ReduceL3NormKeepDimModule_basic",
|
||||||
"ReduceMaxAllDims_basic",
|
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
"ReduceMaxFloatModule_basic",
|
|
||||||
"ReduceMaxSignedIntModule_basic",
|
|
||||||
"ReduceMaxUnsignedIntModule_basic",
|
|
||||||
"ReduceMinAlongDimNegative_basic",
|
|
||||||
"ReduceMinAlongDimSignedInt_basic",
|
|
||||||
"ReduceMinAlongDimUnsignedInt_basic",
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
"ReduceMinAlongDim_basic",
|
|
||||||
"ReduceMinFloatModule_basic",
|
|
||||||
"ReduceMinKeepDimReturnBoth_basic",
|
|
||||||
"ReduceMinKeepDim_basic",
|
|
||||||
"ReduceMinSignedIntModule_basic",
|
|
||||||
"ReduceMinUnsignedIntModule_basic",
|
|
||||||
"ReduceProdDimIntFloatModule_basic",
|
|
||||||
"ReduceProdDtypeFloatModule_basic",
|
|
||||||
"ReduceProdDtypeIntModule_basic",
|
|
||||||
"ReduceProdElementTypeBoolModule_basic",
|
|
||||||
"ReduceProdFloatModule_basic",
|
|
||||||
"ReduceProdSignedIntModule_basic",
|
|
||||||
"ReduceProdUnsignedIntModule_basic",
|
|
||||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
|
||||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
|
||||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
|
||||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||||
"ReduceSumDtypeFloatModule_basic",
|
|
||||||
"ReduceSumDtypeIntModule_basic",
|
|
||||||
"ReduceSumElementTypeBoolModule_basic",
|
|
||||||
"ReflectionPad1dModule2dInput_Right",
|
"ReflectionPad1dModule2dInput_Right",
|
||||||
"ReflectionPad1dModule2dInput_basic",
|
"ReflectionPad1dModule2dInput_basic",
|
||||||
"ReflectionPad1dModule3dInput_Left",
|
"ReflectionPad1dModule3dInput_Left",
|
||||||
|
@ -3672,7 +3697,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"RollModule_basic",
|
"RollModule_basic",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"RsubIntModule_basic",
|
|
||||||
"RsubIntModule_noalpha_basic",
|
"RsubIntModule_noalpha_basic",
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
|
@ -3801,6 +3825,17 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"HstackBasicComplexModule_basic",
|
||||||
|
"HstackBasicFloatModule_basic",
|
||||||
|
"HstackBasicIntFloatModule_basic",
|
||||||
|
"HstackBasicIntModule_basic",
|
||||||
|
"Rot90BasicModule_basic",
|
||||||
|
"Rot90DynamicDimsModule_basic",
|
||||||
|
"Rot90MultipleRotationsModule_basic",
|
||||||
|
"Rot90NegativeEvenRotationsModule_basic",
|
||||||
|
"Rot90NegativeOddRotationsModule_basic",
|
||||||
|
"SafeSoftmaxModule_basic",
|
||||||
|
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||||
|
@ -3916,7 +3951,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ArgminIntModule_basic",
|
"ArgminIntModule_basic",
|
||||||
"ArgminIntModule_multiple_mins",
|
"ArgminIntModule_multiple_mins",
|
||||||
"ArgminModule_basic",
|
"ArgminModule_basic",
|
||||||
"ArgminModule_keepDim",
|
|
||||||
"ArgminModule_with_dim",
|
"ArgminModule_with_dim",
|
||||||
"AtenComplex64Module_basic",
|
"AtenComplex64Module_basic",
|
||||||
"AtenComplexImagModule_basic",
|
"AtenComplexImagModule_basic",
|
||||||
|
@ -4162,7 +4196,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseExpm1Module_basic",
|
"ElementwiseExpm1Module_basic",
|
||||||
"ElementwiseFlattenBroadcastModule_basic",
|
"ElementwiseFlattenBroadcastModule_basic",
|
||||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||||
"ElementwiseFmodTensor_Float_basic",
|
|
||||||
"ElementwiseFmodTensor_Int_Float_basic",
|
"ElementwiseFmodTensor_Int_Float_basic",
|
||||||
"ElementwiseFmodTensor_Int_basic",
|
"ElementwiseFmodTensor_Int_basic",
|
||||||
"ElementwiseGeFloatIntScalarModule_basic",
|
"ElementwiseGeFloatIntScalarModule_basic",
|
||||||
|
@ -4624,7 +4657,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
|
|
|
@ -1373,3 +1373,100 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
|
||||||
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
|
%0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32>
|
||||||
return %0 : !torch.vtensor<[2,4],si32>
|
return %0 : !torch.vtensor<[2,4],si32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.min.dim$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||||
|
// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||||
|
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%values, %indices = torch.aten.min.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
|
||||||
|
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||||
|
return %1 : tensor<3x2x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.min$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
|
%0 = torch.aten.min %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.max$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<1xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> {
|
||||||
|
%0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.prod.dim_int$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> {
|
||||||
|
%dim = torch.constant.int 2
|
||||||
|
%keepdims = torch.constant.bool true
|
||||||
|
%dtype = torch.constant.none
|
||||||
|
%0 = torch.aten.prod.dim_int %arg0, %dim, %keepdims, %dtype: !torch.vtensor<[3,2,3],f32> , !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.all.dim$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xi1>) -> tensor<3x2x1xi1>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x2x1xi1> -> !torch.vtensor<[3,2,1],i1>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,2,1],i1>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> {
|
||||||
|
%dim = torch.constant.int 2
|
||||||
|
%keepdims = torch.constant.bool true
|
||||||
|
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
|
||||||
|
return %0 : !torch.vtensor<[3,2,1],i1>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue