mirror of https://github.com/llvm/torch-mlir
[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)
- Add legalization for aten.div rounding mode: + trunc: rounds division results towards zero + floor: rounds division results down - Add legalization for aten.remainder.Scalar and aten.fmod ops - Add legalization for aten.ge.Tensor op - Update e2e tests in xfail_sets.py - Update basic.mlir with new legalized ops Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3718/head
parent
5ce48dfacd
commit
abaff58c6d
|
@ -463,6 +463,119 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Function to perform division with trunc rounding mode (rounding result
|
||||||
|
// towards zero) for float type inputs.
|
||||||
|
// This function takes in the division result between lhs and rhs rather
|
||||||
|
// than takes in the original lhs and rhs tensors as parameters.
|
||||||
|
Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
|
||||||
|
TensorType outType, Value divResult) {
|
||||||
|
// To implement trunc mode for float inputs, multiply the floored abs
|
||||||
|
// of the tensor with the elementwise signedness of the tensor.
|
||||||
|
// div_result = lhs / rhs
|
||||||
|
// trunc_val = floor(abs(div_result)) * sign(div_result)
|
||||||
|
auto zero =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
|
||||||
|
.value();
|
||||||
|
|
||||||
|
auto one =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
|
||||||
|
.value();
|
||||||
|
|
||||||
|
auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
|
||||||
|
outType.getElementType())
|
||||||
|
.value();
|
||||||
|
|
||||||
|
auto cond = rewriter.create<tosa::GreaterEqualOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
|
||||||
|
divResult, zero);
|
||||||
|
|
||||||
|
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
|
||||||
|
one, minusOne);
|
||||||
|
|
||||||
|
auto absDivResult =
|
||||||
|
rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);
|
||||||
|
|
||||||
|
auto flooredAbsDivResult =
|
||||||
|
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);
|
||||||
|
|
||||||
|
Value result =
|
||||||
|
tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
|
||||||
|
selectOp, /*shift=*/0)
|
||||||
|
.getResult();
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to perform division with trunc rounding mode (rounding result
|
||||||
|
// towards zero) for float type inputs
|
||||||
|
Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
|
||||||
|
TensorType outType, Value lhs, Value rhs) {
|
||||||
|
rhs = tosa::promoteType(rewriter, rhs, outType);
|
||||||
|
|
||||||
|
auto rhsRcp =
|
||||||
|
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);
|
||||||
|
|
||||||
|
auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
|
||||||
|
/*shift=*/0);
|
||||||
|
|
||||||
|
return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to perform division with floor rounding mode (rounding result
|
||||||
|
// down) for integer type inputs.
|
||||||
|
Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
|
||||||
|
Value lhs, Value rhs) {
|
||||||
|
// To implement floor mode int input, utilize tosa::IntDivOp (trunc div
|
||||||
|
// result) with the following formula elementwise:
|
||||||
|
// floor_val = trunc_val - ((trunc_val * rhs != lhs)
|
||||||
|
// && (sign(lhs) != sign(rhs)))
|
||||||
|
|
||||||
|
// TOSA IntDiv requires inputs to be i32
|
||||||
|
auto i32Type =
|
||||||
|
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
|
||||||
|
lhs = tosa::promoteType(rewriter, lhs, i32Type);
|
||||||
|
rhs = tosa::promoteType(rewriter, rhs, i32Type);
|
||||||
|
|
||||||
|
auto intDivOp =
|
||||||
|
rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);
|
||||||
|
|
||||||
|
auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
|
||||||
|
|
||||||
|
auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();
|
||||||
|
|
||||||
|
auto boolType =
|
||||||
|
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
|
||||||
|
|
||||||
|
auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
|
||||||
|
/*shift=*/0);
|
||||||
|
|
||||||
|
auto lhsRhsDifferentSign =
|
||||||
|
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
|
||||||
|
|
||||||
|
auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
|
||||||
|
intDivOp, rhs, /*shift=*/0);
|
||||||
|
|
||||||
|
auto truncMulRhsEqualLhs =
|
||||||
|
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
|
||||||
|
|
||||||
|
auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
|
||||||
|
op->getLoc(), boolType, truncMulRhsEqualLhs);
|
||||||
|
|
||||||
|
auto truncMinusOne =
|
||||||
|
rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);
|
||||||
|
|
||||||
|
auto cond = rewriter.create<tosa::LogicalAndOp>(
|
||||||
|
op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);
|
||||||
|
|
||||||
|
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
|
||||||
|
truncMinusOne, intDivOp);
|
||||||
|
|
||||||
|
Value result = tosa::promoteType(rewriter, selectOp, outType);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename AtenOpT>
|
template <typename AtenOpT>
|
||||||
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
|
||||||
public:
|
public:
|
||||||
|
@ -498,25 +611,64 @@ public:
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()));
|
op.getType()));
|
||||||
|
|
||||||
// auto result;
|
// Get rounding mode for aten.div.Tensor_mode
|
||||||
|
std::string roundMode;
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
|
||||||
|
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Non-const rounding mode parameter unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
Value result;
|
Value result;
|
||||||
if (isa<mlir::FloatType>(outType.getElementType())) {
|
if (isa<mlir::FloatType>(outType.getElementType())) {
|
||||||
// The input to the reciprocal is an integer sometimes, and we may need to
|
// The input to the reciprocal is an integer sometimes, and we may need
|
||||||
// promote it to a floating point. Per TOSA specification, the input types
|
// to promote it to a floating point. Per TOSA specification, the input
|
||||||
// can only be floating point for tosa::ReciprocalOp.
|
// types can only be floating point for tosa::ReciprocalOp.
|
||||||
Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType);
|
rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
|
||||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
|
||||||
op->getLoc(), rhsCasted.getType(), rhsCasted);
|
op->getLoc(), rhsTensor.getType(), rhsTensor);
|
||||||
|
|
||||||
result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||||
rcpOp.getResult(), /*shift=*/0)
|
rhsRcp, /*shift=*/0);
|
||||||
.getResult();
|
|
||||||
|
// Round result based on rounding mode
|
||||||
|
if (roundMode.compare("floor") == 0) {
|
||||||
|
// "floor": rounds the results of the division down. Equivalent to
|
||||||
|
// floor division in Python (the // operator).
|
||||||
|
auto floorOp =
|
||||||
|
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);
|
||||||
|
|
||||||
|
result = floorOp.getResult();
|
||||||
|
} else if (roundMode.compare("trunc") == 0) {
|
||||||
|
// "trunc": rounds the results of the division towards zero. Equivalent
|
||||||
|
// to C-style integer division.
|
||||||
|
result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
|
||||||
|
} else {
|
||||||
|
// None: No rounding mode
|
||||||
|
result = divResult.getResult();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// The output type can be different than the input types (e.g. dividing an
|
if (roundMode.compare("floor") == 0) {
|
||||||
// int tensor results in a floating point tensor).
|
// "floor": rounds the results of the division down. Equivalent to floor
|
||||||
result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
|
// division in Python (the // operator).
|
||||||
rewriter, op, outType, lhs, rhsTensor)
|
result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
|
||||||
.getResult();
|
} else {
|
||||||
|
// "trunc": rounds the results of the division towards zero. Equivalent
|
||||||
|
// to C-style integer division.
|
||||||
|
// None: no rounding mode.
|
||||||
|
|
||||||
|
// TOSA IntDiv requires inputs to be i32
|
||||||
|
auto i32Type = RankedTensorType::get(outType.getShape(),
|
||||||
|
rewriter.getIntegerType(32));
|
||||||
|
lhs = tosa::promoteType(rewriter, lhs, i32Type);
|
||||||
|
rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);
|
||||||
|
|
||||||
|
auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
|
||||||
|
lhs, rhsTensor);
|
||||||
|
|
||||||
|
result = tosa::promoteType(rewriter, intDivOp, outType);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(op, {result});
|
rewriter.replaceOp(op, {result});
|
||||||
|
@ -4524,56 +4676,94 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <typename AtenOpT>
|
||||||
LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
|
||||||
AtenRemainderScalarOp op, OpAdaptor adaptor,
|
public:
|
||||||
ConversionPatternRewriter &rewriter) const {
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Remainder");
|
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
|
||||||
|
|
||||||
auto outType =
|
auto outType =
|
||||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
Type outElemTy = outType.getElementType();
|
Type outElemTy = outType.getElementType();
|
||||||
if (!outElemTy.isIntOrFloat())
|
if (!outElemTy.isIntOrFloat())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point or integer datatype legalization supported");
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
|
|
||||||
Value otherTensor;
|
Value otherTensor;
|
||||||
Value other = op.getOther();
|
if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>()) {
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
|
Value other = op.getOther();
|
||||||
outElemTy, {})))
|
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
|
||||||
return rewriter.notifyMatchFailure(
|
outElemTy, {})))
|
||||||
op, "Currently only scalar constants are supported for "
|
return rewriter.notifyMatchFailure(
|
||||||
"conversion in TOSA Remainder operation");
|
op, "Currently only scalar constants are supported for "
|
||||||
|
"conversion in TOSA Remainder/Fmod operation");
|
||||||
|
} else {
|
||||||
|
otherTensor = adaptor.getOther();
|
||||||
|
auto otherTy = cast<RankedTensorType>(otherTensor.getType());
|
||||||
|
|
||||||
if (selfTy.getElementType() != outElemTy)
|
if (!otherTy)
|
||||||
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
|
||||||
|
}
|
||||||
|
|
||||||
auto divTensor = self;
|
constexpr bool isRemainderOp =
|
||||||
if (isa<mlir::FloatType>(outElemTy)) {
|
std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
|
||||||
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
|
||||||
op.getLoc(), otherTensor.getType(), otherTensor);
|
std::is_same<AtenOpT, AtenRemainderIntOp>();
|
||||||
divTensor = rewriter.create<tosa::MulOp>(
|
|
||||||
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
|
if (selfTy.getElementType() != outElemTy)
|
||||||
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
|
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
|
||||||
} else {
|
|
||||||
divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
|
Value divTensor;
|
||||||
otherTensor);
|
if (isRemainderOp) {
|
||||||
|
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
|
||||||
|
if (isa<mlir::FloatType>(outElemTy)) {
|
||||||
|
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
||||||
|
op.getLoc(), otherTensor.getType(), otherTensor);
|
||||||
|
divTensor = rewriter.create<tosa::MulOp>(
|
||||||
|
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
|
||||||
|
divTensor =
|
||||||
|
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
|
||||||
|
} else {
|
||||||
|
divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
|
||||||
|
if (isa<mlir::FloatType>(outElemTy)) {
|
||||||
|
divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
|
||||||
|
} else {
|
||||||
|
// TOSA IntDiv requires inputs to be i32
|
||||||
|
auto i32Type = RankedTensorType::get(outType.getShape(),
|
||||||
|
rewriter.getIntegerType(32));
|
||||||
|
self = tosa::promoteType(rewriter, self, i32Type);
|
||||||
|
otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);
|
||||||
|
|
||||||
|
auto intDivTensor = rewriter.create<tosa::IntDivOp>(
|
||||||
|
op->getLoc(), i32Type, self, otherTensor);
|
||||||
|
|
||||||
|
divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
|
||||||
|
otherTensor, divTensor,
|
||||||
|
/*shift=*/0);
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
};
|
||||||
auto mulTensor =
|
|
||||||
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
|
|
||||||
/*shift=*/0);
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename AtenOpT, typename TosaOpT>
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
@ -5649,6 +5839,7 @@ public:
|
||||||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
||||||
|
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||||
|
@ -5673,8 +5864,19 @@ public:
|
||||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||||
|
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
|
||||||
|
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
|
||||||
#undef INSERT_BINARY_DIV_PATTERN
|
#undef INSERT_BINARY_DIV_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
|
||||||
|
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
|
||||||
|
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
|
||||||
|
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
|
||||||
|
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
|
||||||
|
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
|
||||||
|
|
||||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||||
|
@ -5828,7 +6030,6 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
|
||||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||||
|
|
|
@ -1668,6 +1668,40 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||||
|
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||||
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
|
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||||
|
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseGeFloatTensorModule_basic",
|
||||||
|
"ElementwiseGeIntTensorModule_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
|
"TriuBroadcastModule_basic",
|
||||||
|
"TriuModule_basic",
|
||||||
"ArgminIntModule_basic",
|
"ArgminIntModule_basic",
|
||||||
"ArgminIntModule_multiple_mins",
|
"ArgminIntModule_multiple_mins",
|
||||||
"ArgminModule_basic",
|
"ArgminModule_basic",
|
||||||
|
@ -2210,6 +2244,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"ArgminIntModule_basic",
|
"ArgminIntModule_basic",
|
||||||
"ArgminIntModule_multiple_mins",
|
"ArgminIntModule_multiple_mins",
|
||||||
"ArgminModule_basic",
|
"ArgminModule_basic",
|
||||||
|
@ -2318,7 +2353,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
"ViewNoChange1dModule_basic",
|
"ViewNoChange1dModule_basic",
|
||||||
"ViewNoChange2dModule_basic",
|
"ViewNoChange2dModule_basic",
|
||||||
"ViewNoChange3dModule_basic",
|
"ViewNoChange3dModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||||
|
@ -3137,7 +3171,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"Rot90MultipleRotationsModule_basic",
|
"Rot90MultipleRotationsModule_basic",
|
||||||
"Rot90NegativeEvenRotationsModule_basic",
|
"Rot90NegativeEvenRotationsModule_basic",
|
||||||
"Rot90NegativeOddRotationsModule_basic",
|
"Rot90NegativeOddRotationsModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
@ -3153,15 +3186,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"EinsumStaticDiagonalDimensionModule_basic",
|
"EinsumStaticDiagonalDimensionModule_basic",
|
||||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
|
||||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRreluEvalModule_basic",
|
"ElementwiseRreluEvalModule_basic",
|
||||||
"ElementwiseRreluEvalStaticModule_basic",
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
"ElementwiseRreluTrainModule_basic",
|
"ElementwiseRreluTrainModule_basic",
|
||||||
|
@ -3194,11 +3218,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"TriuIndicesNegativeOffsetModule_basic",
|
"TriuIndicesNegativeOffsetModule_basic",
|
||||||
"TypeConversionUint8ToF32Module_basic",
|
"TypeConversionUint8ToF32Module_basic",
|
||||||
"WeightNormInterfaceModule_basic",
|
"WeightNormInterfaceModule_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
|
||||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
|
||||||
"AdaptiveAvgPool2dDynamic_basic",
|
|
||||||
"AdaptiveAvgPool3dDynamicNoBatch_basic",
|
"AdaptiveAvgPool3dDynamicNoBatch_basic",
|
||||||
"AdaptiveAvgPool3dDynamic_basic",
|
"AdaptiveAvgPool3dDynamic_basic",
|
||||||
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
||||||
|
@ -3370,11 +3389,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
"ElementwiseAtanhIntModule_basic",
|
"ElementwiseAtanhIntModule_basic",
|
||||||
"ElementwiseAtanhModule_basic",
|
"ElementwiseAtanhModule_basic",
|
||||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
|
||||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
|
||||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
|
||||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
|
||||||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
|
||||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||||
|
@ -3402,25 +3416,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseCoshModule_basic",
|
"ElementwiseCoshModule_basic",
|
||||||
"ElementwiseDequantizePerChannelModule_basic",
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
"ElementwiseDequantizePerTensorModule_basic",
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
|
||||||
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
|
||||||
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
|
||||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
|
||||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
|
||||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
|
||||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseErfModule_basic",
|
"ElementwiseErfModule_basic",
|
||||||
"ElementwiseExpIntModule_basic",
|
"ElementwiseExpIntModule_basic",
|
||||||
"ElementwiseExpm1IntModule_basic",
|
"ElementwiseExpm1IntModule_basic",
|
||||||
"ElementwiseExpm1Module_basic",
|
"ElementwiseExpm1Module_basic",
|
||||||
"ElementwiseGeFloatTensorModule_basic",
|
|
||||||
"ElementwiseGeIntTensorModule_basic",
|
|
||||||
"ElementwiseGeluApproximateTanhModule_basic",
|
"ElementwiseGeluApproximateTanhModule_basic",
|
||||||
"ElementwiseHardshrinkModule_basic",
|
"ElementwiseHardshrinkModule_basic",
|
||||||
"ElementwiseHardshrinkStaticModule_basic",
|
"ElementwiseHardshrinkStaticModule_basic",
|
||||||
|
@ -3448,10 +3448,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
"ElementwiseReciprocalIntModule_basic",
|
"ElementwiseReciprocalIntModule_basic",
|
||||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Float_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
|
||||||
"ElementwiseRsqrtIntModule_basic",
|
"ElementwiseRsqrtIntModule_basic",
|
||||||
"ElementwiseSigmoidIntModule_basic",
|
"ElementwiseSigmoidIntModule_basic",
|
||||||
"ElementwiseSinIntModule_basic",
|
"ElementwiseSinIntModule_basic",
|
||||||
|
@ -3850,6 +3846,7 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"HstackBasicComplexModule_basic",
|
"HstackBasicComplexModule_basic",
|
||||||
"HstackBasicFloatModule_basic",
|
"HstackBasicFloatModule_basic",
|
||||||
"HstackBasicIntFloatModule_basic",
|
"HstackBasicIntFloatModule_basic",
|
||||||
|
@ -3890,8 +3887,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
|
||||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
|
@ -4223,11 +4218,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||||
"ElementwiseFmodTensor_Int_Float_basic",
|
"ElementwiseFmodTensor_Int_Float_basic",
|
||||||
"ElementwiseFmodTensor_Int_basic",
|
"ElementwiseFmodTensor_Int_basic",
|
||||||
"ElementwiseGeFloatIntScalarModule_basic",
|
|
||||||
"ElementwiseGeFloatScalarModule_basic",
|
|
||||||
"ElementwiseGeFloatTensorModule_basic",
|
|
||||||
"ElementwiseGeIntScalarModule_basic",
|
|
||||||
"ElementwiseGeIntTensorModule_basic",
|
|
||||||
"ElementwiseGeMixedIntScalarModule_basic",
|
"ElementwiseGeMixedIntScalarModule_basic",
|
||||||
"ElementwiseGtMixed2ScalarModule_basic",
|
"ElementwiseGtMixed2ScalarModule_basic",
|
||||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||||
|
@ -4259,7 +4249,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseRelu6Module_basic",
|
"ElementwiseRelu6Module_basic",
|
||||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_basic",
|
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
"ElementwiseRsqrtIntModule_basic",
|
"ElementwiseRsqrtIntModule_basic",
|
||||||
|
@ -4682,7 +4671,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
|
@ -4819,8 +4807,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"TraceSignedIntModule_basic",
|
"TraceSignedIntModule_basic",
|
||||||
"TraceUnsignedIntModule_basic",
|
"TraceUnsignedIntModule_basic",
|
||||||
"TraceUnsignedIntModule_empty",
|
"TraceUnsignedIntModule_empty",
|
||||||
"TriuBroadcastModule_basic",
|
|
||||||
"TriuModule_basic",
|
|
||||||
"TupleModule_basic",
|
"TupleModule_basic",
|
||||||
"TypeAsDifferentModule_basic",
|
"TypeAsDifferentModule_basic",
|
||||||
"TypeConversionF32ToF64Module_basic",
|
"TypeConversionF32ToF64Module_basic",
|
||||||
|
|
|
@ -213,10 +213,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
|
||||||
// CHECK-LABEL: func.func @torch.aten.div$basic(
|
// CHECK-LABEL: func.func @torch.aten.div$basic(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
@ -1470,3 +1470,205 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch
|
||||||
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
|
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
|
||||||
return %0 : !torch.vtensor<[3,2,1],i1>
|
return %0 : !torch.vtensor<[3,2,1],i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_trunc(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : (tensor<?x?xi1>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_13]], %[[VAL_11]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_15]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||||
|
%str = torch.constant.str "trunc"
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_trunc(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||||
|
%str = torch.constant.str "trunc"
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_floor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||||
|
%str = torch.constant.str "floor"
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_floor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.greater %[[VAL_8]], %[[VAL_10]] : (tensor<i32>, tensor<?x?xi32>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_12]], %[[VAL_5]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.logical_not %[[VAL_13]] : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_9]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_14]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_15]], %[[VAL_7]] : (tensor<?x?xi1>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.cast %[[VAL_17]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: return %[[VAL_19]] : !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||||
|
%str = torch.constant.str "floor"
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||||
|
%str = torch.constant.str ""
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.div.Tensor_mode$int_basic(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||||
|
%str = torch.constant.str ""
|
||||||
|
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ge.Tensor$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||||
|
%0 = torch.aten.ge.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||||
|
return %0 : !torch.vtensor<[?,?],i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.remainder.Tensor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||||
|
// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
|
||||||
|
%0 = torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2, 4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.fmod.Tensor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_6]] : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xi1>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xi1>, tensor<f32>, tensor<f32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.floor %[[VAL_11]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_10]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_2]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_3]], %[[VAL_14]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||||
|
// CHECK: return %[[VAL_16]] : !torch.vtensor<[2,4],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
|
||||||
|
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2, 4],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue