mirror of https://github.com/llvm/torch-mlir
[TOSA] Extend Torch to TOSA legalization coverage (#3718)
- Add Torch to TOSA legalization for the following ops: + aten.logical_not + aten.logical_xor + aten.cos + aten.sin + aten.pow.Scalar + aten.pow.Tensor_Tensor + aten.erf + aten.bitwise_and.Scalar + aten.bitwise_left_shift.Tensor + aten.bitwise_right_shift.Tensor + aten.le.Tensor + aten.le.Scalar - Update e2e tests in xfail_sets - Update basic.mlir with newly legalized ops Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3763/head
parent
abaff58c6d
commit
3f79a2982a
|
@ -105,9 +105,18 @@ public:
|
|||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
auto binaryOp =
|
||||
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
|
||||
rewriter.replaceOp(op, binaryOp.getResult());
|
||||
Value binaryOp;
|
||||
|
||||
// TOSA ArithmeticRightShiftOp has a round parameter.
|
||||
if constexpr (std::is_same<AtenOpT, AtenBitwiseRightShiftTensorOp>()) {
|
||||
binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
|
||||
/*round=*/false);
|
||||
} else {
|
||||
binaryOp =
|
||||
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, binaryOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -353,6 +362,7 @@ public:
|
|||
// For bitwise operators, only integer datatype legalization is supported
|
||||
constexpr bool isBitwiseOp =
|
||||
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
|
||||
std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
|
||||
if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
|
||||
|
@ -372,7 +382,9 @@ public:
|
|||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
// There is no Lesser operator in TOSA.
|
||||
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
|
||||
std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLeScalarOp>());
|
||||
|
||||
// Promote lhs and rhs dtypes for bitwise operators.
|
||||
TensorType resultTy = cast<TensorType>(
|
||||
|
@ -688,39 +700,30 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||
AtenTanhOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
// Sigmoid legalization in TOSA for quantized element-type uses specialized
|
||||
// tosa.table construct.
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
|
||||
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op, this->getTypeConverter()->convertType(op.getType()), self);
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
|
||||
AtenSigmoidOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<TensorType>(self.getType());
|
||||
if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
// Sigmoid legalization in TOSA for quantized element-type uses
|
||||
// specialized tosa.table construct.
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||
|
@ -1205,39 +1208,63 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
auto outType =
|
||||
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Pow");
|
||||
Value selfTensor;
|
||||
if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
|
||||
Value selfScalar = op.getSelf();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
|
||||
outType.getElementType(), {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA PowScalar operation");
|
||||
} else {
|
||||
selfTensor = adaptor.getSelf();
|
||||
auto selfTy = cast<RankedTensorType>(selfTensor.getType());
|
||||
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Pow");
|
||||
|
||||
auto outType =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
}
|
||||
|
||||
Value expTensor;
|
||||
Value expScalar = op.getExponent();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
|
||||
outType.getElementType(), {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
Value expTensor;
|
||||
if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
|
||||
Value expScalar = op.getExponent();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
|
||||
outType.getElementType(), {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
} else {
|
||||
expTensor = adaptor.getExponent();
|
||||
auto expTy = cast<RankedTensorType>(expTensor.getType());
|
||||
|
||||
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
|
||||
self, expTensor);
|
||||
rewriter.replaceOp(op, powOp.getResult());
|
||||
if (!expTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Pow");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
|
||||
rewriter, op, outType, selfTensor, expTensor);
|
||||
rewriter.replaceOp(op, powOp.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Perform the basic n-dim matmul operation encompassing the handling of
|
||||
// broadcasting and dynamic shape propagation.
|
||||
|
@ -4243,32 +4270,6 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
|
||||
AtenLeTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
|
||||
if (!otherType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types condition are currently supported");
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
auto greaterOp = rewriter.create<tosa::GreaterOp>(
|
||||
op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther());
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType,
|
||||
greaterOp.getOutput());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
|
||||
AtenIscloseOp op, OpAdaptor adaptor,
|
||||
|
@ -5815,6 +5816,9 @@ public:
|
|||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
|
||||
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
|
||||
INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
|
||||
INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||
|
@ -5823,6 +5827,11 @@ public:
|
|||
INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
|
||||
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
|
||||
tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
|
||||
tosa::ArithmeticRightShiftOp)
|
||||
#undef INSERT_BINARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \
|
||||
|
@ -5843,11 +5852,14 @@ public:
|
|||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
|
||||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
@ -5987,16 +5999,30 @@ public:
|
|||
INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
|
||||
#undef INSERT_MASKED_FILL_PATTERN
|
||||
|
||||
#define INSERT_POW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
|
||||
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
|
||||
#undef INSERT_POW_OP_PATTERN
|
||||
|
||||
#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
|
||||
context);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
|
||||
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
|
||||
#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConvolutionOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
|
@ -6023,7 +6049,6 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAbsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLeTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
|
|
|
@ -1702,6 +1702,35 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"TriuBroadcastModule_basic",
|
||||
"TriuModule_basic",
|
||||
"AtenHannWindowPeriodicFalseModule_basic",
|
||||
"AtenHannWindowPeriodicTrueModule_basic",
|
||||
"ElementwiseAndScalarModule_basic",
|
||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseLeFloatIntScalarModule_basic",
|
||||
"ElementwiseLeFloatScalarModule_basic",
|
||||
"ElementwiseLeFloatTensorNanModule_basic",
|
||||
"ElementwiseLeIntScalarModule_basic",
|
||||
"ElementwiseLeMixedIntScalarModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
"ElementwisePowTensorBroadcastModule_basic",
|
||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||
"ElementwisePowTensorModule_basic",
|
||||
"ElementwisePowTensorStaticModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
|
@ -2245,6 +2274,8 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
| {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
|
@ -3200,10 +3231,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"MultinomialModule_basic",
|
||||
"RenormModuleFloat16_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
|
@ -3254,8 +3281,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenEyeMModuleInt2D_basic",
|
||||
"AtenEyeModuleInt2D_basic",
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenHannWindowPeriodicTrueModule_basic",
|
||||
"AtenHannWindowPeriodicFalseModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
|
@ -3373,8 +3398,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAcoshIntModule_basic",
|
||||
"ElementwiseAcoshModule_basic",
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
"ElementwiseAndScalarModule_basic",
|
||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
||||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAsinModule_basic",
|
||||
"ElementwiseAsinhIntModule_basic",
|
||||
|
@ -3392,44 +3415,23 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
"ElementwiseClampTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
"ElementwiseHardshrinkModule_basic",
|
||||
"ElementwiseHardshrinkStaticModule_basic",
|
||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||
"ElementwiseLeFloatIntScalarModule_basic",
|
||||
"ElementwiseLeFloatScalarModule_basic",
|
||||
"ElementwiseLeFloatTensorNanModule_basic",
|
||||
"ElementwiseLeIntScalarModule_basic",
|
||||
"ElementwiseLeMixedIntScalarModule_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog10Module_basic",
|
||||
"ElementwiseLog1pModule_basic",
|
||||
|
@ -3440,18 +3442,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseMishModule_basic",
|
||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
"ElementwisePowTensorBroadcastModule_basic",
|
||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||
"ElementwisePowTensorModule_basic",
|
||||
"ElementwisePowTensorStaticModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
"ElementwiseSinhIntModule_basic",
|
||||
"ElementwiseSinhModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
|
@ -4169,9 +4165,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpRandomModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseAndModule_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||
|
@ -4190,7 +4184,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseClampTensorInt8Module_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
|
@ -4210,7 +4203,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseEqBoolScalarModule_basic",
|
||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
|
@ -4222,7 +4214,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseGtMixed2ScalarModule_basic",
|
||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||
"ElementwiseIsinfModule_basic",
|
||||
"ElementwiseLeFloatTensorNanModule_basic",
|
||||
"ElementwiseLeMixedIntScalarModule_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
|
@ -4237,12 +4228,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseNanToNumModule_Basic",
|
||||
"ElementwiseOrTensorModule_basic",
|
||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||
"ElementwisePowModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
"ElementwisePowTensorBroadcastModule_basic",
|
||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||
"ElementwisePowTensorModule_basic",
|
||||
"ElementwisePowTensorStaticModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
|
@ -4255,7 +4240,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseSgnModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
"ElementwiseSinhIntModule_basic",
|
||||
"ElementwiseSinhModule_basic",
|
||||
"ElementwiseSqrtIntModule_basic",
|
||||
|
@ -4414,8 +4398,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"LinalgNormKeepDimComplexModule_basic",
|
||||
"LinalgNormModule_basic",
|
||||
"LinalgVectorNormComplexModule_basic",
|
||||
"LinalgVectorNormKeepDimModule_basic",
|
||||
"LinalgVectorNormModule_basic",
|
||||
"LogSoftmaxBackwardModule_basic",
|
||||
"LogSoftmaxIntModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
|
@ -4503,8 +4485,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"NativeGroupNormBackwardModule_basic",
|
||||
"NativeGroupNormModule_basic",
|
||||
"NativeLayerNormDynamicModule_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
"NewEmptyStridedModuleDefaultDtype_basic",
|
||||
|
|
|
@ -1672,3 +1672,190 @@ func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !tor
|
|||
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
|
||||
return %0 : !torch.vtensor<[2, 4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.logical_not(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[4,5],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
|
||||
%0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
|
||||
return %0 : !torch.vtensor<[4,5],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.cos(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
%0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sin(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
%0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.pow.Scalar(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
%float2.000000e00 = torch.constant.float 2.000000e+00
|
||||
%0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.erf$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.bitwise_and.Scalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32>
|
||||
return %0 : !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.le.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.le.Scalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.logical_xor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
%0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||
return %0: !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||
return %0: !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue