mirror of https://github.com/llvm/torch-mlir
[TOSA] Extend Torch to TOSA legalization coverage (#3718)
- Add Torch to TOSA legalization for the following ops: + aten.logical_not + aten.logical_xor + aten.cos + aten.sin + aten.pow.Scalar + aten.pow.Tensor_Tensor + aten.erf + aten.bitwise_and.Scalar + aten.bitwise_left_shift.Tensor + aten.bitwise_right_shift.Tensor + aten.le.Tensor + aten.le.Scalar - Update e2e tests in xfail_sets - Update basic.mlir with newly legalized ops Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3763/head
parent abaff58c6d
commit 3f79a2982a
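To give a concrete picture of the coverage listed above: most of the newly legalized unary ops map one-to-one onto TOSA ops, as exercised by the basic.mlir tests added at the end of this change. A minimal sketch (dynamic shapes and value names are illustrative, not taken verbatim from the patch): the input

  %0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>

becomes, after conversion to the builtin tensor world,

  %1 = tosa.erf %in : (tensor<?x?xf32>) -> tensor<?x?xf32>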
@@ -105,9 +105,18 @@ public:
             OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
                 op.getType()));
 
-    auto binaryOp =
-        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
-    rewriter.replaceOp(op, binaryOp.getResult());
+    Value binaryOp;
+
+    // TOSA ArithmeticRightShiftOp has a round parameter.
+    if constexpr (std::is_same<AtenOpT, AtenBitwiseRightShiftTensorOp>()) {
+      binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
+                                          /*round=*/false);
+    } else {
+      binaryOp =
+          tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
+    }
+
+    rewriter.replaceOp(op, binaryOp);
     return success();
   }
 };
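In MLIR terms, the branch above means aten.bitwise_right_shift.Tensor, unlike the other binary ops routed through createBinaryOpAndCast, is emitted with an explicit rounding flag. Roughly (shapes and value names illustrative, mirroring the new basic.mlir test later in this diff):

  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
  // lowers to
  %1 = tosa.arithmetic_right_shift %lhs, %rhs {round = false} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>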
@@ -353,6 +362,7 @@ public:
     // For bitwise operators, only integer datatype legalization is supported
     constexpr bool isBitwiseOp =
         std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
+        std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
         std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
         std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
     if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
@@ -372,7 +382,9 @@ public:
     auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
     // There is no Lesser operator in TOSA.
     constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
-                                 std::is_same<AtenOpT, AtenLtScalarOp>());
+                                 std::is_same<AtenOpT, AtenLtScalarOp>() ||
+                                 std::is_same<AtenOpT, AtenLeTensorOp>() ||
+                                 std::is_same<AtenOpT, AtenLeScalarOp>());
 
     // Promote lhs and rhs dtypes for bitwise operators.
     TensorType resultTy = cast<TensorType>(
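Because TOSA has no lesser/lesser_equal ops, aten.le is expressed by swapping the operands of a tosa.greater_equal. A sketch mirroring the new basic.mlir test (shapes and value names illustrative):

  %0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
  // lowers to (note the swapped operand order: other >= self)
  %1 = tosa.greater_equal %rhs, %lhs : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>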
@@ -688,39 +700,30 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
-    AtenTanhOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::TanhOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses specialized
-  // tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
-
-template <>
-LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
-    AtenSigmoidOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses
-  // specialized tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value self = adaptor.getSelf();
+    auto selfTy = cast<TensorType>(self.getType());
+
+    if (!selfTy)
+      return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
+
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization currently supported");
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op, this->getTypeConverter()->convertType(op.getType()), self);
+
+    return success();
+  }
+};
 
 template <>
 LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
@@ -1205,13 +1208,29 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
-    AtenPowTensorScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+template <typename AtenOpT>
+class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<RankedTensorType>(self.getType());
+    auto outType =
+        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
 
+    Value selfTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
+      Value selfScalar = op.getSelf();
+      if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA PowScalar operation");
+    } else {
+      selfTensor = adaptor.getSelf();
+      auto selfTy = cast<RankedTensorType>(selfTensor.getType());
+
       if (!selfTy)
         return rewriter.notifyMatchFailure(
@@ -1220,24 +1239,32 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
       if (!isa<mlir::FloatType>(selfTy.getElementType()))
         return rewriter.notifyMatchFailure(
             op, "Only floating-point datatype legalization supported");
+    }
 
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-
     Value expTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
       Value expScalar = op.getExponent();
       if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
                                          outType.getElementType(), {})))
         return rewriter.notifyMatchFailure(
             op, "Currently only scalar constants are supported for "
                 "conversion in TOSA Pow operation");
+    } else {
+      expTensor = adaptor.getExponent();
+      auto expTy = cast<RankedTensorType>(expTensor.getType());
 
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        self, expTensor);
+      if (!expTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
+    }
+
+    auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
+        rewriter, op, outType, selfTensor, expTensor);
     rewriter.replaceOp(op, powOp.getResult());
 
     return success();
   }
+};
 
 // Perform the basic n-dim matmul operation encompassing the handling of
 // broadcasting and dynamic shape propagation.
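For the new aten.pow.Scalar path above, the scalar base goes through torchScalarToTosaTensor and becomes a rank-0 tosa.const that feeds tosa.pow. A sketch mirroring the new basic.mlir test (value names illustrative):

  %0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
  // lowers to
  %base = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
  %1 = tosa.pow %base, %exp : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>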
@@ -4243,32 +4270,6 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
-    AtenLeTensorOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  // Not a tensor type.
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
-  if (!selfType)
-    return rewriter.notifyMatchFailure(
-        op, "Only tensor types input are currently supported");
-  auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
-  if (!otherType)
-    return rewriter.notifyMatchFailure(
-        op, "Only tensor types condition are currently supported");
-
-  auto outType = getTypeConverter()->convertType(op.getType());
-
-  auto greaterOp = rewriter.create<tosa::GreaterOp>(
-      op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther());
-
-  rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType,
-                                                  greaterOp.getOutput());
-
-  return success();
-}
-
 template <>
 LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
     AtenIscloseOp op, OpAdaptor adaptor,
@@ -5815,6 +5816,9 @@ public:
     INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
     INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
     INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
+    INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
+    INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
+    INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
 #undef INSERT_UNARY_PATTERN
 
 #define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                  \
@@ -5823,6 +5827,11 @@ public:
     INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
+    INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
+                          tosa::LogicalLeftShiftOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
+                          tosa::ArithmeticRightShiftOp)
 #undef INSERT_BINARY_PATTERN
 
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                           \
@@ -5843,11 +5852,14 @@ public:
     INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
 #undef INSERT_BINARY_COMPARE_PATTERN
@@ -5987,16 +5999,30 @@ public:
     INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
 #undef INSERT_MASKED_FILL_PATTERN
 
+#define INSERT_POW_OP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
+    INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
+    INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
+    INSERT_POW_OP_PATTERN(AtenPowScalarOp);
+#undef INSERT_POW_OP_PATTERN
+
+#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp)                  \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
+                                                                context);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
+#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-    INSERT_ATENOP_PATTERN(AtenTanhOp);
     INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenSigmoidOp);
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
     INSERT_ATENOP_PATTERN(AtenConvolutionOp);
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
@@ -6023,7 +6049,6 @@ public:
     INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
     INSERT_ATENOP_PATTERN(AtenAbsOp);
     INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
-    INSERT_ATENOP_PATTERN(AtenLeTensorOp);
     INSERT_ATENOP_PATTERN(AtenClampOp);
     INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
     INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
@@ -1702,6 +1702,35 @@ TOSA_PASS_SET = {
     "ElementwiseRemainderTensorModule_Int_basic",
     "TriuBroadcastModule_basic",
     "TriuModule_basic",
+    "AtenHannWindowPeriodicFalseModule_basic",
+    "AtenHannWindowPeriodicTrueModule_basic",
+    "ElementwiseAndScalarModule_basic",
+    "ElementwiseAndScalarStaticShapeModule_basic",
+    "ElementwiseAtenLogicalNotOpModule_basic",
+    "ElementwiseAtenLogicalXorOpModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseBitwiseAndScalarInt32Module_basic",
+    "ElementwiseBitwiseAndScalarInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt32Module_basic",
+    "ElementwiseBitwiseLeftShiftInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt8Module_basic",
+    "ElementwiseBitwiseRightShiftInt32Module_basic",
+    "ElementwiseBitwiseRightShiftInt64Module_basic",
+    "ElementwiseBitwiseRightShiftInt8Module_basic",
+    "ElementwiseCosModule_basic",
+    "ElementwiseErfModule_basic",
+    "ElementwiseLeFloatIntScalarModule_basic",
+    "ElementwiseLeFloatScalarModule_basic",
+    "ElementwiseLeFloatTensorNanModule_basic",
+    "ElementwiseLeIntScalarModule_basic",
+    "ElementwiseLeMixedIntScalarModule_basic",
+    "ElementwisePowScalarModule_basic",
+    "ElementwisePowTensorBroadcastModule_basic",
+    "ElementwisePowTensorBroadcastStaticModule_basic",
+    "ElementwisePowTensorModule_basic",
+    "ElementwisePowTensorStaticModule_basic",
+    "ElementwiseSinModule_basic",
     "ArgminIntModule_basic",
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
@@ -2245,6 +2274,8 @@ MAKE_FX_TOSA_PASS_SET = (
     | {
         ### Tests additionally passing in make_fx_tosa
         "AdaptiveAvgPool1dStaticLargerOutput_basic",
+        "ScaledDotProductAttentionBoolMaskModule_basic",
+        "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
         "ArgminIntModule_basic",
         "ArgminIntModule_multiple_mins",
         "ArgminModule_basic",
@@ -3200,10 +3231,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "MultinomialModule_basic",
     "RenormModuleFloat16_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
-    "ScaledDotProductAttentionBoolMaskModule_basic",
-    "ScaledDotProductAttentionDifferentCausalModule_basic",
-    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionSameCausalModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex128IntModule_basic",

@@ -3254,8 +3281,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "AtenEyeMModuleInt2D_basic",
     "AtenEyeModuleInt2D_basic",
     "AtenFloatScalarModule_basic",
-    "AtenHannWindowPeriodicTrueModule_basic",
-    "AtenHannWindowPeriodicFalseModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",
     "AtenIntBoolOpModule_basic",

@@ -3373,8 +3398,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseAcoshIntModule_basic",
     "ElementwiseAcoshModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
-    "ElementwiseAndScalarModule_basic",
-    "ElementwiseAndScalarStaticShapeModule_basic",
     "ElementwiseAsinIntModule_basic",
     "ElementwiseAsinModule_basic",
     "ElementwiseAsinhIntModule_basic",

@@ -3392,44 +3415,23 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseAtenLogicalAndOpModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
-    "ElementwiseAtenLogicalNotOpModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
-    "ElementwiseAtenLogicalXorOpModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
-    "ElementwiseBitwiseAndScalarInt32Module_basic",
-    "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
-    "ElementwiseBitwiseLeftShiftInt32Module_basic",
-    "ElementwiseBitwiseLeftShiftInt64Module_basic",
-    "ElementwiseBitwiseLeftShiftInt8Module_basic",
-    "ElementwiseBitwiseRightShiftInt32Module_basic",
-    "ElementwiseBitwiseRightShiftInt64Module_basic",
-    "ElementwiseBitwiseRightShiftInt8Module_basic",
     "ElementwiseClampMinTensorFloatModule_basic",
     "ElementwiseClampMinTensorIntModule_basic",
     "ElementwiseClampTensorFloatModule_basic",
     "ElementwiseClampTensorIntModule_basic",
     "ElementwiseCosIntModule_basic",
-    "ElementwiseCosModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseErfModule_basic",
     "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseGeluApproximateTanhModule_basic",
-    "ElementwiseHardshrinkModule_basic",
-    "ElementwiseHardshrinkStaticModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
-    "ElementwiseLeFloatIntScalarModule_basic",
-    "ElementwiseLeFloatScalarModule_basic",
-    "ElementwiseLeFloatTensorNanModule_basic",
-    "ElementwiseLeIntScalarModule_basic",
-    "ElementwiseLeMixedIntScalarModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog10Module_basic",
     "ElementwiseLog1pModule_basic",

@@ -3440,18 +3442,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseMishModule_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
-    "ElementwisePowScalarModule_basic",
-    "ElementwisePowTensorBroadcastModule_basic",
-    "ElementwisePowTensorBroadcastStaticModule_basic",
-    "ElementwisePowTensorModule_basic",
-    "ElementwisePowTensorStaticModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
     "ElementwiseRsqrtIntModule_basic",
     "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinIntModule_basic",
-    "ElementwiseSinModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
     "ElementwiseTanIntModule_basic",
@@ -4169,9 +4165,7 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseAtenLogicalOrOpNegativeModule_basic",
     "ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
     "ElementwiseAtenLogicalOrOpRandomModule_basic",
-    "ElementwiseAtenLogicalXorOpModule_basic",
     "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
     "ElementwiseBitwiseAndModule_basic",
     "ElementwiseBitwiseLeftShiftInt32Module_basic",
     "ElementwiseBitwiseLeftShiftInt64Module_basic",

@@ -4190,7 +4184,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseClampModule_basic",
     "ElementwiseClampTensorInt8Module_basic",
     "ElementwiseCosIntModule_basic",
-    "ElementwiseCosModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",

@@ -4210,7 +4203,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseEqBoolScalarModule_basic",
     "ElementwiseEqDiffWidthScalarModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseErfModule_basic",
     "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",

@@ -4222,7 +4214,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseGtMixed2ScalarModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseIsinfModule_basic",
-    "ElementwiseLeFloatTensorNanModule_basic",
     "ElementwiseLeMixedIntScalarModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",

@@ -4237,12 +4228,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseNanToNumModule_Basic",
     "ElementwiseOrTensorModule_basic",
     "ElementwiseOrTensorStaticShapeModule_basic",
-    "ElementwisePowModule_basic",
-    "ElementwisePowScalarModule_basic",
-    "ElementwisePowTensorBroadcastModule_basic",
-    "ElementwisePowTensorBroadcastStaticModule_basic",
-    "ElementwisePowTensorModule_basic",
-    "ElementwisePowTensorStaticModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",

@@ -4255,7 +4240,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseSgnModule_basic",
     "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinIntModule_basic",
-    "ElementwiseSinModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
     "ElementwiseSqrtIntModule_basic",

@@ -4414,8 +4398,6 @@ ONNX_TOSA_XFAIL_SET = {
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgNormModule_basic",
     "LinalgVectorNormComplexModule_basic",
-    "LinalgVectorNormKeepDimModule_basic",
-    "LinalgVectorNormModule_basic",
     "LogSoftmaxBackwardModule_basic",
     "LogSoftmaxIntModule_basic",
     "MaskedFillTensorFloatValueModule_basic",

@@ -4503,8 +4485,6 @@ ONNX_TOSA_XFAIL_SET = {
     "NativeGroupNormBackwardModule_basic",
     "NativeGroupNormModule_basic",
     "NativeLayerNormDynamicModule_basic",
-    "NativeLayerNormModule4D_basic",
-    "NativeLayerNormModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
@@ -1672,3 +1672,190 @@ func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
   %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
   return %0 : !torch.vtensor<[2, 4],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.logical_not(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
+// CHECK:           %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[4,5],i1>
+// CHECK:         }
+func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
+  return %0 : !torch.vtensor<[4,5],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.cos(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.sin(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.pow.Scalar(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.float 2.000000e+00
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
+// CHECK:         }
+func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %float2.000000e00 = torch.constant.float 2.000000e+00
+  %0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.pow.Tensor_Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.erf$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_and.Scalar$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK:           %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.le.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.le.Scalar$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.logical_xor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK:           %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_left_shift.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0: !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.bitwise_right_shift.Tensor$basic(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:           %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0: !torch.vtensor<[?,?],si32>
+}