[TOSA] Extend Torch to TOSA legalization coverage (#3718)

- Add Torch to TOSA legalization for the following ops:
  + aten.logical_not
  + aten.logical_xor
  + aten.cos
  + aten.sin
  + aten.pow.Scalar
  + aten.pow.Tensor_Tensor
  + aten.erf
  + aten.bitwise_and.Scalar
  + aten.bitwise_left_shift.Tensor
  + aten.bitwise_right_shift.Tensor
  + aten.le.Tensor
  + aten.le.Scalar
- Update e2e tests in xfail_sets
- Update basic.mlir with newly legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3727/head
Justin Ngo 2024-09-20 14:33:55 -07:00 committed by GitHub
parent abaff58c6d
commit 3f79a2982a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 333 additions and 141 deletions

View File

@ -105,9 +105,18 @@ public:
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
auto binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
rewriter.replaceOp(op, binaryOp.getResult());
Value binaryOp;
// TOSA ArithmeticRightShiftOp has a round parameter.
if constexpr (std::is_same<AtenOpT, AtenBitwiseRightShiftTensorOp>()) {
binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
/*round=*/false);
} else {
binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
}
rewriter.replaceOp(op, binaryOp);
return success();
}
};
@ -353,6 +362,7 @@ public:
// For bitwise operators, only integer datatype legalization is supported
constexpr bool isBitwiseOp =
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
@ -372,7 +382,9 @@ public:
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>());
std::is_same<AtenOpT, AtenLtScalarOp>() ||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
std::is_same<AtenOpT, AtenLeScalarOp>());
// Promote lhs and rhs dtypes for bitwise operators.
TensorType resultTy = cast<TensorType>(
@ -688,39 +700,30 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
// Sigmoid legalization in TOSA for quantized element-type uses specialized
// tosa.table construct.
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (!selfTy)
return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
rewriter.replaceOpWithNewOp<TosaOpT>(
op, this->getTypeConverter()->convertType(op.getType()), self);
template <>
LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
AtenSigmoidOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = cast<TensorType>(self.getType());
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
// Sigmoid legalization in TOSA for quantized element-type uses
// specialized tosa.table construct.
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}
};
template <>
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
@ -1205,39 +1208,63 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
}
};
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
template <typename AtenOpT>
class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());
auto outType =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
Value selfTensor;
if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
Value selfScalar = op.getSelf();
if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
outType.getElementType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA PowScalar operation");
} else {
selfTensor = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(selfTensor.getType());
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
auto outType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}
Value expTensor;
Value expScalar = op.getExponent();
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
outType.getElementType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
Value expTensor;
if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
Value expScalar = op.getExponent();
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
outType.getElementType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
} else {
expTensor = adaptor.getExponent();
auto expTy = cast<RankedTensorType>(expTensor.getType());
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
self, expTensor);
rewriter.replaceOp(op, powOp.getResult());
if (!expTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
}
return success();
}
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
rewriter, op, outType, selfTensor, expTensor);
rewriter.replaceOp(op, powOp.getResult());
return success();
}
};
// Perform the basic n-dim matmul operation encompassing the handling of
// broadcasting and dynamic shape propagation.
@ -4243,32 +4270,6 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
AtenLeTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
if (!otherType)
return rewriter.notifyMatchFailure(
op, "Only tensor types condition are currently supported");
auto outType = getTypeConverter()->convertType(op.getType());
auto greaterOp = rewriter.create<tosa::GreaterOp>(
op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther());
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType,
greaterOp.getOutput());
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
AtenIscloseOp op, OpAdaptor adaptor,
@ -5815,6 +5816,9 @@ public:
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
@ -5823,6 +5827,11 @@ public:
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
tosa::ArithmeticRightShiftOp)
#undef INSERT_BINARY_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
@ -5843,11 +5852,14 @@ public:
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
#undef INSERT_BINARY_COMPARE_PATTERN
@ -5987,16 +5999,30 @@ public:
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
#undef INSERT_MASKED_FILL_PATTERN
#define INSERT_POW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
#undef INSERT_POW_OP_PATTERN
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
context);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
@ -6023,7 +6049,6 @@ public:
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenLeTensorOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);

View File

@ -1702,6 +1702,35 @@ TOSA_PASS_SET = {
"ElementwiseRemainderTensorModule_Int_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"AtenHannWindowPeriodicFalseModule_basic",
"AtenHannWindowPeriodicTrueModule_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalXorOpModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
"ElementwiseBitwiseRightShiftInt32Module_basic",
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCosModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseLeFloatIntScalarModule_basic",
"ElementwiseLeFloatScalarModule_basic",
"ElementwiseLeFloatTensorNanModule_basic",
"ElementwiseLeIntScalarModule_basic",
"ElementwiseLeMixedIntScalarModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwiseSinModule_basic",
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
@ -2245,6 +2274,8 @@ MAKE_FX_TOSA_PASS_SET = (
| {
### Tests additionally passing in make_fx_tosa
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ArgminIntModule_basic",
"ArgminIntModule_multiple_mins",
"ArgminModule_basic",
@ -3200,10 +3231,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MultinomialModule_basic",
"RenormModuleFloat16_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex128IntModule_basic",
@ -3254,8 +3281,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenHannWindowPeriodicTrueModule_basic",
"AtenHannWindowPeriodicFalseModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",
@ -3373,8 +3398,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAsinModule_basic",
"ElementwiseAsinhIntModule_basic",
@ -3392,44 +3415,23 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseAtenLogicalXorOpModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
"ElementwiseBitwiseRightShiftInt32Module_basic",
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseHardshrinkModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseLeFloatIntScalarModule_basic",
"ElementwiseLeFloatScalarModule_basic",
"ElementwiseLeFloatTensorNanModule_basic",
"ElementwiseLeIntScalarModule_basic",
"ElementwiseLeMixedIntScalarModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog1pModule_basic",
@ -3440,18 +3442,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseMishModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic",
@ -4169,9 +4165,7 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
"ElementwiseAtenLogicalOrOpRandomModule_basic",
"ElementwiseAtenLogicalXorOpModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
@ -4190,7 +4184,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseClampModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
@ -4210,7 +4203,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseEqBoolScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
@ -4222,7 +4214,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseIsinfModule_basic",
"ElementwiseLeFloatTensorNanModule_basic",
"ElementwiseLeMixedIntScalarModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
@ -4237,12 +4228,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseNanToNumModule_Basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwisePowModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@ -4255,7 +4240,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseSgnModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseSqrtIntModule_basic",
@ -4414,8 +4398,6 @@ ONNX_TOSA_XFAIL_SET = {
"LinalgNormKeepDimComplexModule_basic",
"LinalgNormModule_basic",
"LinalgVectorNormComplexModule_basic",
"LinalgVectorNormKeepDimModule_basic",
"LinalgVectorNormModule_basic",
"LogSoftmaxBackwardModule_basic",
"LogSoftmaxIntModule_basic",
"MaskedFillTensorFloatValueModule_basic",
@ -4503,8 +4485,6 @@ ONNX_TOSA_XFAIL_SET = {
"NativeGroupNormBackwardModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormDynamicModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyStridedModuleDefaultDtype_basic",

View File

@ -1672,3 +1672,190 @@ func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !tor
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_not(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[4,5],i1>
// CHECK: }
func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
%0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
return %0 : !torch.vtensor<[4,5],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.cos(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.sin(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.pow.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%float2.000000e00 = torch.constant.float 2.000000e+00
%0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.erf$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_and.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: }
func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%int2 = torch.constant.int 2
%0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32>
return %0 : !torch.vtensor<[?,?],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.le.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.le.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%int2 = torch.constant.int 2
%0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_xor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: }
func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0: !torch.vtensor<[?,?],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: }
func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0: !torch.vtensor<[?,?],si32>
}