mirror of https://github.com/llvm/torch-mlir

[TOSA] Expand Torch to TOSA legalization coverage (#3827)

- Add/Extend Torch to TOSA legalization for the following ops:
  + Add aten.threshold_backward
  + Fix aten.threshold
  + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile
  + Add support for rank 0 index for aten.index_select
  + Fix aten.index_put.hacked_twin
  + Add aten.uniform
  + Add aten.logical_and
- Update xfail_sets.py with new e2e results
- Add LIT tests to basic.mlir for newly added ops

Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3
Signed-off-by: Justin Ngo <justin.ngo@arm.com>

Branch: pull/3839/head
parent a6292f38ca
commit 4dd213b042
TorchToTosa.cpp

@@ -25,6 +25,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include <numeric>
 #include <optional>
+#include <random>
 
 using namespace mlir;
 using namespace mlir::torch;

@@ -125,15 +126,14 @@ template <typename T>
 static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
                            const int64_t &intValue) {
   if (isFloat) {
-    // Do a round-trip check here instead of numeric limits due to
-    // compiler warnings around double <-> int conversion.
-    return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
-  } else {
-    assert(isInt);
+    return (doubleValue >=
+            static_cast<double>(std::numeric_limits<T>::min())) &&
+           (doubleValue <= static_cast<double>(std::numeric_limits<T>::max()));
+  } else if (isInt) {
     return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
            (intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
   }
-  return true;
+  return false;
 }
 
 // FIXME: This will eventually go into a Tosa*Utils file.

@@ -165,13 +165,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
                      dshape, dtype)
                      .value();
   } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
-    auto w = intType.getWidth();
-    if (w != 1 && w != 32 && w != 64)
+    auto width = intType.getWidth();
+    if (width != 1 && width != 8 && width != 32 && width != 64)
       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unsupported integer type: " << intType;
      });
 
-    if (w == 1) {
+    if (width == 1) {
       if (!isInValidRange<bool>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
             op, "Supplied value of scalar constant exceeds limits "

@@ -182,7 +182,18 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       tosaTensor = tosa::getConstTensor<bool>(
                        rewriter, op, SmallVector<bool>(numElem, d), dshape)
                        .value();
-    } else if (w == 32) {
+    } else if (width == 8) {
+      if (!isInValidRange<int8_t>(isFloat, doubleValue, isInt, intValue)) {
+        return rewriter.notifyMatchFailure(
+            op, "Supplied value of scalar constant exceeds limits "
+                "of destination type");
+      }
+      int8_t d = isFloat ? static_cast<int8_t>(doubleValue)
+                         : static_cast<int8_t>(intValue);
+      tosaTensor = tosa::getConstTensor<int8_t>(
+                       rewriter, op, SmallVector<int8_t>(numElem, d), dshape)
+                       .value();
+    } else if (width == 32) {
       if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
             op, "Supplied value of scalar constant exceeds limits "

@@ -193,7 +204,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       tosaTensor = tosa::getConstTensor<int32_t>(
                        rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
                        .value();
-    } else if (w == 64) {
+    } else if (width == 64) {
       if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
             op, "Supplied value of scalar constant exceeds limits "
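For illustration only, not part of the patch: the kind of numeric-limits check that the scalar-to-TOSA-constant helper now performs before materializing a constant of a narrower destination type such as i8. The Python helper below is hypothetical and assumes signed two's-complement limits.

import numpy as np  # not required; plain Python suffices

# Rough sketch of an in-range test against the destination bit width.
def in_valid_range(value: float, bits: int) -> bool:
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return lo <= value <= hi

print(in_valid_range(100, 8))   # True: fits in an 8-bit constant
print(in_valid_range(300, 8))   # False: the lowering rejects the scalar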
@@ -919,13 +930,17 @@ class ConvertAtenMultipleDimsReductionOp
                                   ConversionPatternRewriter &rewriter,
                                   ElementsAttr &reduceDimsAttr,
                                   bool &keepDims) const override {
-    SmallVector<int64_t, 4> reduceDims;
-    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims)))
-      return rewriter.notifyMatchFailure(op,
-                                         "non-const dim parameter unsupported");
-    int64_t N = reduceDims.size();
     int64_t inputRank =
         cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
+
+    SmallVector<int64_t> reduceDims;
+    // If dim list is none, all dimensions are reduced
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) {
+      for (int64_t i = 0; i < inputRank; i++)
+        reduceDims.push_back(i);
+    }
+
+    int64_t N = reduceDims.size();
     for (unsigned i = 0; i < N; i++) {
       reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
       if (!isValidDim(reduceDims[i], inputRank))

@@ -2895,9 +2910,10 @@ template <>
 LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
     AtenThresholdOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
   // Not a tensor type.
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto selfType = dyn_cast<TensorType>(self.getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types are currently supported");

@@ -2907,12 +2923,9 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only floating-point or integer datatype legalization supported");
 
-  // Integer types with width > 32 are not supported
-  auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
-  if (selfIntType && selfIntType.getWidth() > 32) {
-    return rewriter.notifyMatchFailure(
-        op, "Integer types with width greater than 32 are not supported");
-  }
+  auto outType =
+      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+  auto outElemTy = outType.getElementType();
 
   SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
   Value threshold, value;

@@ -2922,21 +2935,16 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
         op, "Only scalar constant is supported for threshold");
 
   if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value,
-                                     selfElemTy, constTypeShape)))
+                                     outElemTy, constTypeShape)))
     return rewriter.notifyMatchFailure(
         op, "Only scalar constant is supported for value");
 
-  // Threshold only clamps the upper values. tosa::ClampOp has the same
-  // value for both threshold and clamped value so cannot be used.
-  auto outType = getTypeConverter()->convertType(op.getType());
-
   auto cmpOp = rewriter.create<tosa::GreaterOp>(
       op.getLoc(),
       RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
-      adaptor.getSelf(), threshold);
+      self, threshold);
 
-  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
-                                              adaptor.getSelf(), value);
+  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp, self, value);
 
   return success();
 }
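For illustration only, not part of the patch: the element-wise behaviour the aten.threshold lowering expresses with tosa.greater followed by tosa.select, namely "self > threshold ? self : value". A minimal NumPy sketch:

import numpy as np

x = np.array([-1.0, 0.5, 2.0])
threshold, value = 1.0, 7.0
out = np.where(x > threshold, x, value)   # keep x where above threshold, else use value
print(out)                                # [7. 7. 2.]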
@@ -3660,8 +3668,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
     AtenBroadcastToOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
+  auto self = adaptor.getSelf();
   // Not a tensor type.
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto selfType = dyn_cast<TensorType>(self.getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");

@@ -3675,19 +3684,43 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
   SmallVector<int64_t> resultShape;
   if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
     return rewriter.notifyMatchFailure(op,
-                                       "size must consist of Scalar constants");
+                                       "Size must consist of Scalar constants");
 
+  int64_t inputRank = selfType.getRank();
+  int64_t outputRank = resultShape.size();
+  if (inputRank > outputRank)
+    return rewriter.notifyMatchFailure(
+        op, "Input tensor rank cannot be greater than output tensor rank");
+
   // Get the result type
   auto resultType = getTypeConverter()->convertType(op.getType());
 
   SmallVector<int64_t> inputShape(
       makeShapeTorchCompatible(selfType.getShape()));
 
+  // If input rank is smaller than output rank, we reshape the input tensor to
+  // be the same rank as the output tensor by prepending 1s to the input shape
+  SmallVector<int64_t> targetInputShape;
+  for (int64_t i = 0; i < outputRank - inputRank; i++)
+    targetInputShape.push_back(1);
+  targetInputShape.append(inputShape);
+
   // Result dimension -1 means not changing the size of that dimension.
   // Adjust it by assigning its inputShape.
-  for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) {
+  for (auto shape :
+       llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) {
     auto index = shape.index();
     if (resultShape[index] == -1)
       resultShape[index] = shape.value();
   }
 
+  for (int64_t i = 0; i < outputRank; i++) {
+    if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Input and result shapes should be equal at each dimension or "
+              "input shape should be 1");
+  }
+
   // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
   // true then we can replace the op result with the input operand directly.
   if (llvm::equal(inputShape, resultShape)) {

@@ -3695,52 +3728,40 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
     // since the input and result are of same shape.
     op.replaceAllUsesWith(op.getSelf());
     rewriter.eraseOp(op);
+  } else {
+    // By using reshape and tile ops, support for input rank smaller than result
+    // rank is allowed. If the rank is smaller, we reshape the input to be the
+    // same rank as the result, then use tile to expand it. The way it was
+    // handled before involves adding the input tensor to a const zero tensor of
+    // output shape to utilize the innate broadcast feature of the TOSA add op.
+    // That poses the danger of sign bit flips for denormalized values.
+    // Basically, this approach to broadcast_to legalization allows for more
+    // flexibility in rank differences and also offers more safety.
+    Value reshapedInput = self;
+    if (!llvm::equal(inputShape, targetInputShape))
+      reshapedInput = rewriter.create<tosa::ReshapeOp>(
+          op->getLoc(),
+          RankedTensorType::get(makeShapeTorchCompatible(targetInputShape),
+                                selfElemTy),
+          self, rewriter.getDenseI64ArrayAttr(targetInputShape));
+
+    SmallVector<int64_t> tileOpShape;
+    for (int64_t i = 0; i < outputRank; i++) {
+      if (targetInputShape[i] == 1) {
+        tileOpShape.push_back(resultShape[i]);
+      } else {
+        tileOpShape.push_back(1);
+      }
+    }
+
+    auto result = rewriter.create<tosa::TileOp>(
+        op->getLoc(), resultType, reshapedInput,
+        rewriter.getDenseI64ArrayAttr(tileOpShape));
+
+    rewriter.replaceOp(op, {result.getResult()});
+  }
+
   return success();
-  } else if (selfType.hasRank() &&
-             (selfType.getRank() == (int64_t)resultShape.size() ||
-              selfType.getRank() == 0)) {
-    // Right now to support limited cases where input and result shape are not
-    // equal, we can put a constraint that either the input should be of rank
-    // 0 or the rank of input tensor and result should be equal. And then we
-    // can check for broadcasting compatibility for the latter case. For
-    // broadcasting compatibility, either the shape of input and result should
-    // be equal at each dimenion or one of them should be 1.
-    if (selfType.getRank() != 0) {
-      for (unsigned i = 0; i < inputShape.size(); i++) {
-        if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
-            resultShape[i] != 1) {
-          return rewriter.notifyMatchFailure(
-              op, "unimplemented: either the shape of input and result should "
-                  "be equal at each dimenion or one of them should be 1.");
-        }
-      }
-    }
-
-    // If the above condition hold true then we can directly create a const
-    // zero tensor of shape same as the result shape.
-    SmallVector<int64_t> zeroTensorShape{resultShape};
-
-    // create the 0 constant tensor
-    int64_t totalNumElements = 1;
-    for (auto dimSize : zeroTensorShape) {
-      totalNumElements = dimSize * totalNumElements;
-    }
-    // There is some danger here. For edge cases in floating point, x + 0 != x.
-    // The cases are denormalized values, which may get flushed, and -0 + 0 =
-    // +0. (sign bit flips). These are probably acceptable in the short term,
-    // but we should put a comment acknowledging the danger, as there isn't an
-    // op that avoids the denorm flushing.
-    Value zeroTensor =
-        tosa::getZerosLikeTensor(rewriter, op, resultType).value();
-
-    // Use add broadcast
-    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
-                                             zeroTensor);
-    return success();
-  }
-  return rewriter.notifyMatchFailure(
-      op,
-      "unimplemented: broadcasts other than same rank or zero ranked tensor.");
 }
 
 template <>
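For illustration only, not part of the patch: the reshape-then-tile strategy that the re-implemented aten.broadcast_to lowering uses. A lower-rank input is first reshaped by prepending 1s up to the output rank, then tiled along every dimension whose size is 1. A minimal NumPy sketch:

import numpy as np

x = np.arange(3)                       # input shape (3,), broadcast target (2, 4, 3)
target = (2, 4, 3)
padded = x.reshape((1,) * (len(target) - x.ndim) + x.shape)          # -> (1, 1, 3)
tiles = tuple(t if p == 1 else 1 for p, t in zip(padded.shape, target))  # (2, 4, 1)
out = np.tile(padded, tiles)           # tile only where the padded dim is 1
print(out.shape)                       # (2, 4, 3)

Unlike the previous add-with-zeros trick, this never performs floating-point arithmetic on the data, so denormal flushing and sign-bit issues cannot occur.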
@@ -3843,6 +3864,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
 
   auto index = adaptor.getIndex();
   auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto indexShape = indexType.getShape();
 
   if (!indexType)
     return rewriter.notifyMatchFailure(

@@ -3851,9 +3873,13 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
   auto inputShape = inputType.getShape();
   int inputRank = inputType.getRank();
 
-  if (indexType.getRank() == 0)
-    return rewriter.notifyMatchFailure(
-        op, "Rank 0 index tensor is currently not supported");
+  if (indexType.getRank() == 0) {
+    indexShape = makeShapeTorchCompatible({1});
+    index = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(),
+        RankedTensorType::get(indexShape, indexType.getElementType()), index,
+        rewriter.getDenseI64ArrayAttr(indexShape));
+  }
 
   // Dynamic shape check
   if (!inputType.hasStaticShape() || !indexType.hasStaticShape())

@@ -3865,9 +3891,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
   if (indexType.getElementType() != rewriter.getIntegerType(32)) {
     index = rewriter.create<tosa::CastOp>(
         op->getLoc(),
-        RankedTensorType::get(indexType.getShape(),
-                              rewriter.getIntegerType(32)),
-        index);
+        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
   }
 
   // Get positive dim

@@ -3896,7 +3920,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
   SmallVector<int64_t> indicesInputRankShape;
   for (int64_t i = 0; i < inputRank; i++) {
     if (i == dim) {
-      indicesInputRankShape.push_back(indexType.getShape()[0]);
+      indicesInputRankShape.push_back(indexShape[0]);
     } else {
       indicesInputRankShape.push_back(1);
     }
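For illustration only, not part of the patch: the new rank-0 index handling for aten.index_select simply treats a scalar index as a one-element index vector before gathering. A minimal NumPy sketch:

import numpy as np

x = np.array([[1, 2], [3, 4], [5, 6]])
idx = np.array(2)                 # rank-0 index tensor
idx = idx.reshape(1)              # reshape to shape [1], as the lowering does
print(np.take(x, idx, axis=0))    # [[5 6]]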
@@ -3952,49 +3976,41 @@ template <>
 LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
     AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  // a = torch.tensor([[0, 1, 2, 3]])
-  // a[..., 1:] = torch.tensor([4, 5, 6])
-  // = a[..., 1:4] = torch.tensor([4, 5, 6])
-  // = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5,
-  // 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
-  //                                 (torch.tensor([0, 0, 0]), torch.tensor([1, 2,
-  //                                 3])), # indicies torch.tensor([4, 5, 6])) #
-  //                                 value
-  // = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
-  //                            (None, torch.tensor([1, 2, 3]),),# indicies
-  //                            torch.tensor([4, 5, 6])) # value
-
   // Not a tensor type.
   auto input = adaptor.getSelf();
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  auto selfType = dyn_cast<TensorType>(input.getType());
   if (!selfType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");
 
   auto fillValues = adaptor.getValues();
-  auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
+  auto valuesType = dyn_cast<TensorType>(fillValues.getType());
   if (!valuesType)
     return rewriter.notifyMatchFailure(
         op, "Only tensor types input are currently supported");
 
   // Deal with torch.prim.ListConstruct of non const value to get the index
+  // Index_put-like ops are now decomposed to aten.index_put.hacked_twin with
+  // stricter semantics, i.e., no None index in indices argument.
   auto tensorList = op.getIndices();
   SmallVector<Value> tensorsTorchType;
   if (!getListConstructElements(tensorList, tensorsTorchType))
-    return op.emitError(
-        "unimplemented: the tensor list is not from list construct");
+    return op.emitError("Tensor list is not from list construct");
   auto indexTensors = getTypeConvertedValues(
       rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
 
   auto outType = getTypeConverter()->convertType(op.getType());
 
-  // convert list of indices with none into indices tensor without none
-  // indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3])
-  // ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]]
-  if (indexTensors.size() <= 1) {
-    return rewriter.notifyMatchFailure(
-        op, "Only support indexput with multiple index.");
-  }
+  bool accumulate{false};
+  if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
+    return rewriter.notifyMatchFailure(
+        op, "Accumulate is not a constant bool value");
+
+  // No support for accumulate mode yet
+  if (accumulate)
+    return rewriter.notifyMatchFailure(
+        op, "Accumulate mode is not currently supported");
 
   SmallVector<Value> indicesTfConcatTensors;
   SmallVector<int64_t> indexesRank;
   SmallVector<SmallVector<int64_t>> indexesShape;

@@ -4002,28 +4018,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
     auto index = indexTensors[i];
-    auto indexTorch = tensorsTorchType[i];
-    // TODO add support for none index other than i==0, like (index0, None)
-    // (None, index1)
-    if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
-      // convert None to [0,0,0]
-      auto indexNext = indexTensors[i + 1];
-      auto indexNextTorch = tensorsTorchType[i + 1];
-      if (isa<Torch::NoneType>(indexNextTorch.getType())) {
-        return rewriter.notifyMatchFailure(
-            op, "Multiple None index is not support for now.");
-      }
-      auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
-      auto indexNextShape = indexNextType.getShape();
-
-      int64_t size = 1;
-      for (auto s : indexNextShape)
-        size *= s;
-      SmallVector<int32_t> values(size, i);
-      index =
-          tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
-              .value();
-    }
-
     auto indexType = dyn_cast<RankedTensorType>(index.getType());
     auto indexShape = indexType.getShape();

@@ -4031,20 +4025,19 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
     indexesRank.push_back(indexType.getRank());
 
     // index i64 to i32 for tosa compatible
-    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+    if (indexType.getElementType() != rewriter.getIntegerType(32))
       index = rewriter.create<tosa::CastOp>(
           op->getLoc(),
           RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
           index);
-    }
 
     // Expand last dim of index to tf indices [3] -> [3,1]
     // convert [0,0,0] to [[0],[0],[0]]
     SmallVector<int64_t> indiceShapeOneDim;
-    for (auto shape : indexShape) {
+    for (auto shape : indexShape)
       indiceShapeOneDim.push_back(shape);
-    }
     indiceShapeOneDim.push_back(1);
 
     auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
         rewriter, op->getLoc(),
         RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),

@@ -4061,7 +4054,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   for (auto indexShapeOneDim : indexesShape) {
     if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: Only support multi indexes with same shape");
+          op, "Only support indices with same shape");
     }
   }
 

@@ -4075,19 +4068,16 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
       GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
       indicesTfConcatTensors, lastDim);
 
-  if (!indicesTf) {
-    return rewriter.notifyMatchFailure(op,
-                                       "Convert TorchIndex To TfIndices fail.");
-  }
-  // do the tf scatterNd algorithm with tf style indices as input, algorithm
-  // mostly take from convertGatherNdOp.
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert PyTorch index to TensorFlow indices failed");
+
   auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
                                          indicesTf.getResult(), fillValues);
 
-  if (!result) {
-    return rewriter.notifyMatchFailure(
-        op, "Convert ScatterNdOp fail for index tensor.");
-  }
+  if (!result)
+    return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
+
   rewriter.replaceOp(op, {result.value()});
 
   return success();
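For illustration only, not part of the patch: the scatter-style update that the aten.index_put.hacked_twin lowering expresses through convertScatterNdOp. Each per-dimension index tensor contributes one coordinate column; the stacked (N, rank) coordinates select where the values are written (non-accumulating). A minimal NumPy sketch:

import numpy as np

a = np.array([[0, 1, 2, 3]])
rows = np.array([0, 0, 0])
cols = np.array([1, 2, 3])
values = np.array([4, 5, 6])
coords = np.stack([rows, cols], axis=-1)    # shape (3, 2): one (row, col) per update
out = a.copy()
out[coords[:, 0], coords[:, 1]] = values    # scatter the values at the coordinates
print(out)                                  # [[0 4 5 6]]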
@@ -6632,6 +6622,140 @@ LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.uniform
+// Since TOSA hasn't got a built-in random generator yet, we will use
+// std::uniform_real_distribution with the std::default_random_engine from C++
+// <random> library
+template <>
+LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
+    AtenUniformOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a tensor type
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  auto selfShape = selfType.getShape();
+
+  auto generator = adaptor.getGenerator();
+  if (!isa<Torch::NoneType>(generator.getType()))
+    return rewriter.notifyMatchFailure(op,
+                                       "Custom generators are not supported");
+
+  double fromDouble{0.0}, toDouble{1.0};
+  auto isFloat =
+      matchPattern(op.getFrom(), m_TorchConstantFloat(&fromDouble)) &&
+      matchPattern(op.getTo(), m_TorchConstantFloat(&toDouble));
+
+  int64_t fromInt{0}, toInt{1};
+  auto isInt = matchPattern(op.getFrom(), m_TorchConstantInt(&fromInt)) &&
+               matchPattern(op.getTo(), m_TorchConstantInt(&toInt));
+
+  if (!isFloat && !isInt)
+    return rewriter.notifyMatchFailure(
+        op, "From and To values are not constant values");
+
+  int64_t numElem = 1;
+  for (int64_t i = 0; i < selfType.getRank(); i++)
+    numElem *= selfShape[i];
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+
+  std::default_random_engine gen;
+
+  auto from = isFloat ? fromDouble : fromInt;
+  auto to = isFloat ? toDouble : toInt;
+
+  std::uniform_real_distribution<float> uniformDist(from, to);
+  SmallVector<float> uniformVec;
+
+  for (int64_t i = 0; i < numElem; i++)
+    uniformVec.push_back(uniformDist(gen));
+
+  auto result = tosa::getConstTensor<float>(rewriter, op, uniformVec, selfShape,
+                                            selfType.getElementType())
+                    .value();
+
+  result = tosa::promoteType(rewriter, result, resultType);
+
+  rewriter.replaceOp(op, {result});
+
+  return success();
+}
+
+// Legalization for aten.threshold_backward
+// result = self <= threshold ? 0 : grad
+template <>
+LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
+    AtenThresholdBackwardOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a tensor type
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+  auto selfElemTy = selfType.getElementType();
+
+  auto selfShape = selfType.getShape();
+
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultElemTy = resultType.getElementType();
+
+  Value threshold;
+  if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold,
+                                     selfElemTy, selfShape)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Threshold must be a constant scalar");
+
+  auto grad = adaptor.getGradOutput();
+
+  // Not a tensor type
+  auto gradType = dyn_cast<TensorType>(grad.getType());
+  if (!gradType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
+
+  Value zero =
+      TypeSwitch<Type, Value>(resultElemTy)
+          .Case<mlir::FloatType>([&](auto) {
+            return tosa::getConstTensor<float>(rewriter, op, 0, {},
+                                               resultElemTy)
+                .value();
+          })
+          .Case<mlir::IntegerType>([&](auto intType) {
+            switch (intType.getWidth()) {
+            case 1:
+              return tosa::getConstTensor<bool>(rewriter, op, 0, {}).value();
+            case 8:
+              return tosa::getConstTensor<int8_t>(rewriter, op, 0, {}).value();
+            case 32:
+              return tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+            case 64:
+              return tosa::getConstTensor<int64_t>(rewriter, op, 0, {}).value();
+            }
+            llvm_unreachable("Invalid integer width");
+          });
+
+  // Check: input <= threshold
+  auto cond = rewriter.create<tosa::GreaterEqualOp>(
+      op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
+      threshold, self);
+
+  self = tosa::promoteType(rewriter, self, resultType);
+  grad = tosa::promoteType(rewriter, grad, resultType);
+
+  auto result = rewriter.create<tosa::SelectOp>(op->getLoc(), resultType,
+                                                cond.getResult(), zero, grad);
+
+  rewriter.replaceOp(op, {result.getResult()});
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
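Two illustrations of the newly added legalizations, not part of the patch itself.

First, aten.uniform: as the comment in the hunk notes, TOSA has no random generator, so the lowering folds the op into a constant filled with host-generated uniform values between the constant "from" and "to" bounds. A minimal Python sketch of that fill:

import random

rng = random.Random(0)                      # stand-in for std::default_random_engine
from_val, to_val, num_elem = 0.0, 1.0, 4
uniform_vec = [rng.uniform(from_val, to_val) for _ in range(num_elem)]
print(uniform_vec)

Second, aten.threshold_backward: the lowering implements the rule stated in the comment, result = self <= threshold ? 0 : grad, via tosa.greater_equal and tosa.select. A minimal NumPy sketch:

import numpy as np

x = np.array([-1.0, 0.5, 2.0])
grad = np.array([10.0, 20.0, 30.0])
threshold = 1.0
print(np.where(x <= threshold, 0.0, grad))   # [ 0.  0. 30.]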
@@ -6705,6 +6829,7 @@ public:
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
     INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
     INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
                           tosa::LogicalLeftShiftOp)
     INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,

@@ -6947,6 +7072,8 @@ public:
     INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
     INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
     INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
+    INSERT_ATENOP_PATTERN(AtenUniformOp);
+    INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
xfail_sets.py

@@ -1707,9 +1707,17 @@ TOSA_CRASHING_SET = {
     "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
     "ReduceAllDimEmpty_basic",
+    # SmallVector unable to grow for ThresholdBackward1d
+    "ThresholdBackward1dFloatModule_basic",
+    "ThresholdBackward1dIntModule_basic",
+    "ThresholdBackward1dMixedModule_basic",
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
+    "GridSamplerBasic1_basic",
+    "GridSamplerBasic2_basic",
+    "GridSamplerBasic3_basic",
+    "GridSamplerBasic4_basic",
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
@@ -1727,6 +1735,25 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "CosineSimilarityStaticBroadcastModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "ElementwiseAtenLogicalAndOpModule_basic",
+    "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
+    "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "ElementwiseRreluTrainStaticModule_basic",
+    "IndexSelectRank0IdxModule_basic",
+    "MseLossSumReductionWithDifferentElemTypeModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
+    "RandIntDtypeModule_basic",
+    "RandIntLowDtypeModule_basic",
+    "RandModule_basic",
+    "ReduceL3NormAllDimsModule_basic",
+    "ReduceL3NormKeepDimModule_basic",
+    "SliceCopy_Module_basic",
+    "Threshold1dIntModule_basic",
+    "Threshold2dIntModule_basic",
+    "Threshold3dIntModule_basic",
     "EmptyModule_contiguous",
     "EmptyModule_defaultDtype",
     "EmptyModule_falsePinMemory",

@@ -2296,8 +2323,6 @@ TOSA_PASS_SET = {
     "TensorIntModule_basic",
     "TensorLiteralModule_basic",
     "TensorOpaqueLiteralModule_basic",
-    "TensorsConcatNegativeDimStaticModule_basic",
-    "TensorsConcatStaticModule_basic",
     "TestF16Return_basic",
     "TestMultipleTensorReturn_basic",
     "Threshold1dFloatModule_basic",

@@ -2363,7 +2388,6 @@ TOSA_PASS_SET = {
     "LinspaceModule_basic",
     "LinspaceOneSizeModule_basic",
     "LinspaceTwoSizeModule_basic",
-    "TorchPrimLoopForLikeTensorArgModule_basic",
     "RenormModuleFloat32NegativeDim_basic",
     "RenormModuleFloat32_basic",
     "IndexTensorStaticContiguousWithNoneModule_basic",

@@ -2468,7 +2492,6 @@ MAKE_FX_TOSA_PASS_SET = (
     "SplitWithSizesListUnpackModule_basic",
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
-    "MatmulStaticBroadcast_basic",
     # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntNonAccumulateModule_basic",

@@ -2487,7 +2510,6 @@ MAKE_FX_TOSA_PASS_SET = (
     "ElementwiseLogSigmoidModule_basic",
     # failed to legalize operation 'torch.aten.rrelu_with_noise'
     "ElementwiseRreluEvalModule_basic",
-    "ElementwiseRreluEvalStaticModule_basic",
     # incompatible return type failure for tosa.concat.
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",

@@ -3329,6 +3351,14 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "AdaptiveMaxPool1dDimOneStatic_basic",
+    "ElementwiseRreluWithNoiseTrainModule_basic",
+    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "MaxPool3dEmptyStrideStaticModule_basic",
+    "MaxPool3dLargeDatadModule_basic",
+    "MaxPool3dModuleRandomSimple_basic",
+    "MaxPool3dModule_basic",
+    "MaxPool3dStaticModule_basic",
     "ViewDtypeStaticModule_basic",
     "Unfold_Module_Dynamic_basic",
     "Unfold_Module_Rank_4",
@@ -3474,7 +3504,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "BoolIntFalseModule_basic",
     "BoolIntTrueModule_basic",
     "BroadcastDynamicDimModule_basic",
-    "BroadcastToModule_basic",
     "CeilFloatModule_basic",
     "CollapseAllDimensionsModule_basic",
     "CollapseFullDynamicModule_basic",

@@ -3509,7 +3538,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     "CopyWithDifferentDTypesModule_basic",
-    "CosineSimilarityStaticBroadcastModule_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",

@@ -3524,8 +3552,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "DeterminantModule_F32",
     "DivFloatModule_basic",
     "DivIntModule_basic",
-    "DropoutTrainModule_basic",
-    "DropoutTrainStaticShapeModule_basic",
     "ElementwiseAcosIntModule_basic",
     "ElementwiseAcosModule_basic",
     "ElementwiseAcoshIntModule_basic",

@@ -3545,11 +3571,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseAtanTensorIntModule_basic",
     "ElementwiseAtanhIntModule_basic",
     "ElementwiseAtanhModule_basic",
-    "ElementwiseAtenLogicalAndOpModule_basic",
-    "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
-    "ElementwiseBitwiseAndScalarInt8Module_basic",
     "ElementwiseClampMinTensorFloatModule_basic",
     "ElementwiseClampMinTensorIntModule_basic",
     "ElementwiseClampTensorFloatModule_basic",

@@ -3590,12 +3612,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
     "EqIntModule_basic",
-    "ExpandModule_basic",
-    "ExponentialModule_basic",
     "FloatImplicitModule_basic",
     "FullLikeModuleInt2D_basic",
     "FullLikeModuleInt3D_basic",
-    "FullModuleInt2D_basic",
     "GeFloatIntModule_basic",
     "GeFloatModule_basic",
     "GeIntModule_basic",
@@ -3606,42 +3625,25 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "IndexPut1DFloatAccumulateModule_basic",
-    "IndexPut1DFloatNonAccumulateModule_basic",
     "IndexPut1DIntAccumulateModule_basic",
-    "IndexPut1DIntNonAccumulateModule_basic",
     "IndexPut2DFloatAccumulateModule_basic",
-    "IndexPut2DFloatNonAccumulateModule_basic",
     "IndexPut2DIntAccumulateModule_basic",
-    "IndexPut2DIntNonAccumulateModule_basic",
     "IndexPut3DFloatAccumulateModule_basic",
-    "IndexPut3DFloatNonAccumulateModule_basic",
     "IndexPut3DIntAccumulateModule_basic",
-    "IndexPut3DIntNonAccumulateModule_basic",
     "IndexPutHackedTwin1DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "IndexPutHackedTwin1DIntAccumulateModule_basic",
-    "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
     "IndexPutHackedTwin2DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
     "IndexPutHackedTwin2DIntAccumulateModule_basic",
-    "IndexPutHackedTwin2DIntNonAccumulateModule_basic",
     "IndexPutHackedTwin3DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
     "IndexPutHackedTwin3DIntAccumulateModule_basic",
-    "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
-    "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
-    "IndexPutImpl1DIntNonAccumulateModule_basic",
     "IndexPutImpl2DFloatAccumulateModule_basic",
-    "IndexPutImpl2DFloatNonAccumulateModule_basic",
     "IndexPutImpl2DImplicitModule_basic",
     "IndexPutImpl2DIndexModule_basic",
     "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexPutImpl3DFloatAccumulateModule_basic",
-    "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "IndexSelectRank0IdxModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",
@ -3656,8 +3658,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"LinspaceDtypeModule_basic",
|
"LinspaceDtypeModule_basic",
|
||||||
"LinspaceEmptyModule_basic",
|
"LinspaceEmptyModule_basic",
|
||||||
"MaskedFillTensorFloatValueModule_basic",
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
"MatmulBroadcastBatchDim_basic",
|
"MaskedScatterStaticBasic_basic",
|
||||||
"MatmulStaticBroadcast_basic",
|
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
|
@ -3689,17 +3690,16 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
||||||
"MaxPool3dWithIndicesStaticModule_basic",
|
"MaxPool3dWithIndicesStaticModule_basic",
|
||||||
"MeanDimEmptyDimModule_basic",
|
"MeanDimEmptyDimModule_basic",
|
||||||
"MeanDimNoneDimModule_basic",
|
"MlGroupNormManualModule_basic",
|
||||||
"MseLossMeanReductionModule_basic",
|
"MlGroupNormModule_basic",
|
||||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
"MlLayerNormManualModule_basic",
|
||||||
|
"MlLayerNormModule_basic",
|
||||||
"MulFloatModule_basic",
|
"MulFloatModule_basic",
|
||||||
"MulIntModule_basic",
|
"MulIntModule_basic",
|
||||||
"NativeBatchNorm1DModule_basic",
|
"NativeBatchNorm1DModule_basic",
|
||||||
"NativeBatchNorm2DModule_basic",
|
"NativeBatchNorm2DModule_basic",
|
||||||
"NativeBatchNorm3DModule_basic",
|
"NativeBatchNorm3DModule_basic",
|
||||||
"NativeBatchNormNoneWeightModule_basic",
|
"NativeBatchNormNoneWeightModule_basic",
|
||||||
"NativeDropoutTrainModule_basic",
|
|
||||||
"NativeDropoutTrainStaticShapeModule_basic",
|
|
||||||
"NativeGroupNormBackwardModule_basic",
|
"NativeGroupNormBackwardModule_basic",
|
||||||
"NeFloatIntModule_basic",
|
"NeFloatIntModule_basic",
|
||||||
"NeIntModule_basic",
|
"NeIntModule_basic",
|
||||||
|
@ -3741,14 +3741,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"QuantizedReluInt8_basic",
|
"QuantizedReluInt8_basic",
|
||||||
"QuantizedReluUint8_basic",
|
"QuantizedReluUint8_basic",
|
||||||
"QuantizedSingleLayer_basic",
|
"QuantizedSingleLayer_basic",
|
||||||
"RandIntDtypeModule_basic",
|
|
||||||
"RandIntLowDtypeModule_basic",
|
|
||||||
"RandIntLowModule_basic",
|
"RandIntLowModule_basic",
|
||||||
"RandIntModule_basic",
|
"RandIntModule_basic",
|
||||||
"RandIntPinMemoryModule_basic",
|
"RandIntPinMemoryModule_basic",
|
||||||
"RandLikeDtypeModule_basic",
|
|
||||||
"RandLikeModule_basic",
|
|
||||||
"RandModule_basic",
|
|
||||||
"RandnDtypeDeviceModule_basic",
|
"RandnDtypeDeviceModule_basic",
|
||||||
"RandnGeneratorF64Module_basic",
|
"RandnGeneratorF64Module_basic",
|
||||||
"RandnGeneratorModule_basic",
|
"RandnGeneratorModule_basic",
|
||||||
|
@ -3760,9 +3755,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReduceL1NormComplexModule_basic",
|
"ReduceL1NormComplexModule_basic",
|
||||||
"ReduceL1NormWithDTypeModule_basic",
|
"ReduceL1NormWithDTypeModule_basic",
|
||||||
"ReduceL2NormComplexModule_basic",
|
"ReduceL2NormComplexModule_basic",
|
||||||
"ReduceL3NormAllDimsModule_basic",
|
|
||||||
"ReduceL3NormKeepDimComplexModule_basic",
|
"ReduceL3NormKeepDimComplexModule_basic",
|
||||||
"ReduceL3NormKeepDimModule_basic",
|
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
"ReduceMinAlongDimUnsignedInt_basic",
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||||
|
@ -3843,18 +3836,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"TensorsConcatPromoteDTypeModule_basic",
|
"TensorsConcatPromoteDTypeModule_basic",
|
||||||
"TensorsStackPromoteDTypeModule_basic",
|
"TensorsStackPromoteDTypeModule_basic",
|
||||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||||
"Threshold1dIntModule_basic",
|
|
||||||
"Threshold2dIntModule_basic",
|
|
||||||
"Threshold3dIntModule_basic",
|
|
||||||
"ThresholdBackward1dFloatModule_basic",
|
|
||||||
"ThresholdBackward1dIntModule_basic",
|
|
||||||
"ThresholdBackward1dMixedModule_basic",
|
|
||||||
"ThresholdBackward2dFloatModule_basic",
|
|
||||||
"ThresholdBackward2dIntModule_basic",
|
|
||||||
"ThresholdBackward2dMixedModule_basic",
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
"ThresholdBackward3dFloatModule_basic",
|
|
||||||
"ThresholdBackward3dIntModule_basic",
|
|
||||||
"ThresholdBackward3dMixedModule_basic",
|
|
||||||
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
||||||
"ToCopyWithDTypeModule_basic",
|
"ToCopyWithDTypeModule_basic",
|
||||||
"TorchPrimLoopForLikeModule_basic",
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
|
@ -3863,10 +3845,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"TraceUnsignedIntModule_empty",
|
"TraceUnsignedIntModule_empty",
|
||||||
"TypeConversionI1ToF64Module_basic",
|
"TypeConversionI1ToF64Module_basic",
|
||||||
"TypeConversionI1ToI32Module_basic",
|
"TypeConversionI1ToI32Module_basic",
|
||||||
"UniformModule_basic",
|
|
||||||
"UniformNoCorrelationModule_basic",
|
|
||||||
"UniformStaticShapeModule_basic",
|
|
||||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
|
||||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"UpSampleNearest2dBackwardScalesNone_basic",
|
"UpSampleNearest2dBackwardScalesNone_basic",
|
||||||
"UpSampleNearest2dBackward_basic",
|
"UpSampleNearest2dBackward_basic",
|
||||||
|
@@ -3875,9 +3853,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "UpSampleNearest2dStaticFactor_basic",
     "UpSampleNearest2dStaticSize_basic",
     "UpSampleNearest2d_basic",
-    "VarMeanBiasedModule_basic",
-    "VarMeanCorrectionNoneModule_basic",
-    "VarMeanUnbiasedModule_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
     "VisionTransformerModule_basic",
@@ -3894,6 +3869,15 @@ ONNX_TOSA_CRASHING_SET = {
 }

 ONNX_TOSA_XFAIL_SET = {
+    "ElementwiseRreluWithNoiseEvalModule_basic",
+    "ElementwiseRreluWithNoiseEvalStaticModule_basic",
+    "ElementwiseRreluWithNoiseTrainModule_basic",
+    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "RreluWithNoiseBackwardEvalModule_basic",
+    "RreluWithNoiseBackwardEvalStaticModule_basic",
+    "RreluWithNoiseBackwardTrainModule_basic",
+    "RreluWithNoiseBackwardTrainStaticModule_basic",
+    "RreluWithNoiseForwardBackwardModule_basic",
     "Unfold_Module_Dynamic_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_Size_Zero_basic",
@@ -3937,12 +3921,10 @@ ONNX_TOSA_XFAIL_SET = {
     "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
-    "EinsumStaticModule_basic",
     "ElementwiseFmaxModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseGeluApproximateTanhModule_basic",
     "ElementwiseIntTensorLtFloatTensorModule_basic",
-    "ElementwiseNanToNumWithNoneModule_Basic",
     "ElementwiseRad2DegIntModule_basic",
     "ElementwiseRad2DegModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
@@ -4106,7 +4088,6 @@ ONNX_TOSA_XFAIL_SET = {
     "BoolIntConstantModule_basic",
     "BoolIntFalseModule_basic",
     "BoolIntTrueModule_basic",
-    "BoolTensorHandleSignless_basic",
     "BroadcastDynamicDimModule_basic",
     "BroadcastToModule_basic",
     "BucketizeTensorFloatModule_basic",
@@ -4123,10 +4104,6 @@ ONNX_TOSA_XFAIL_SET = {
     "CollapseRank1DynamicModule_basic",
     "CollapseStaticModule_basic",
     "ConstantBoolParameterModule_basic",
-    "ConstantPad2dStaticModule_basic",
-    "ConstantPadNdModule_basic",
-    "ConstantPadNdPartialStaticModule_basic",
-    "ConstantPadNdStaticModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
@@ -4220,9 +4197,7 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseAtenFloorDivideTensorPositiveModule_basic",
     "ElementwiseAtenIsneginfOpModule_basic",
     "ElementwiseAtenIsposinfOpModule_basic",
-    "ElementwiseAtenLogicalAndOpModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
-    "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
     "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
@@ -4254,7 +4229,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
-    "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
     "ElementwiseDivScalarRoundingModeTruncModule_basic",
     "ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
     "ElementwiseDivTensorFloatModule_basic",
@@ -4291,7 +4265,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseMulTensorFloatModule_basic",
     "ElementwiseMulTensorIntModule_basic",
-    "ElementwiseNanToNumModule_Basic",
     "ElementwiseOrTensorModule_basic",
     "ElementwiseOrTensorStaticShapeModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
@@ -4579,8 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
     "OnesLikeModule_falsePinMemory",
     "OnesLikeModule_float",
     "OnesLikeModule_int",
-    "PadModule_basic",
-    "PadWithNoneValModule_basic",
     "PermuteNegativeIndexModule_basic",
     "PixelShuffleModuleFullDynamic_basic",
     "PixelShuffleModuleSpatiallyDynamic_basic",
@@ -4688,7 +4659,6 @@ ONNX_TOSA_XFAIL_SET = {
     "ReflectionPad2dModule_Right",
     "ReflectionPad2dModule_Top",
     "ReflectionPad2dModule_basic",
-    "RepeatModule_basic",
     "ReplicationPad2dModule_basic",
     "ReplicationPad2dModule_bottom0",
     "ReplicationPad2dModule_left0",
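A note on reading the xfail hunks above: entries deleted from FX_IMPORTER_TOSA_XFAIL_SET (for example the Threshold*, ThresholdBackward*, Uniform*, and VarMean* modules) are tests that are now expected to pass with the new legalizations, while entries added to ONNX_TOSA_XFAIL_SET (the Rrelu* modules) are newly expected failures. A minimal sketch of that bookkeeping, using a hypothetical refresh helper rather than torch-mlir's actual e2e runner:

def refresh_xfail_set(xfail_set: set[str], failing: set[str]) -> set[str]:
    # Illustrative only, not project API. `failing` is assumed to be the set of
    # test names that currently fail for a given config.
    to_remove = xfail_set - failing   # now passing, delete from the set (as in the hunks above)
    to_add = failing - xfail_set      # newly failing, add to the set
    return (xfail_set | to_add) - to_remove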
@@ -2162,4 +2162,82 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
   %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
   %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
   return %1 : !torch.vtensor<[4,2],si64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
+// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1>
+// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<4xi1>, tensor<i64>, tensor<4xi64>) -> tensor<4xi64>
+// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi64> -> !torch.vtensor<[4],si64>
+// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],si64>
+// CHECK: }
+func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.threshold_backward %arg0, %arg1, %int1 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
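For context, aten.threshold_backward(grad_output, self, threshold) forwards the incoming gradient wherever self is strictly greater than the threshold and produces zero elsewhere, which is the greater_equal/select pattern checked above. A minimal PyTorch sketch of those reference semantics (illustrative only, not part of the patch):

import torch

grad_output = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
self_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.int64)

# The gradient flows only where self > threshold; all other positions become zero.
reference = torch.where(self_tensor > 1, grad_output, torch.zeros_like(grad_output))
result = torch.ops.aten.threshold_backward(grad_output, self_tensor, 1)
assert torch.equal(result, reference)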
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.threshold$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 5.000000e-01
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
+// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_1]], %[[VAL_4]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1>
+// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_6]], %[[VAL_1]], %[[VAL_5]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5],si64>
+// CHECK: }
+func.func @torch.aten.threshold$basic(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.threshold %arg0, %float5.000000e-01, %int2 : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
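Similarly, aten.threshold(self, threshold, value) keeps elements strictly greater than the threshold and substitutes value everywhere else; because the test input is si64, the float threshold 5.000000e-01 shows up as the integer constant 0 in the CHECK lines above. A small PyTorch sketch of the op's behavior (illustrative only):

import torch

x = torch.tensor([[-1.0, 0.2, 0.7], [1.5, 0.5, 3.0]])

# Elements <= 0.5 are replaced by 2.0; elements > 0.5 pass through unchanged.
out = torch.ops.aten.threshold(x, 0.5, 2.0)
assert torch.equal(out, torch.where(x > 0.5, x, torch.full_like(x, 2.0)))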
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
+// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_3]], %[[VAL_2]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,5],i1>
+// CHECK: }
+func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
+  %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
+  return %0 : !torch.vtensor<[4,5],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.uniform$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
+// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
+// CHECK: %[[VAL_3:.*]] = torch.constant.none
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
+// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
+// CHECK: }
+func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
+  %float1.000000e00 = torch.constant.float 1.000000e+00
+  %float1.000000e01 = torch.constant.float 1.000000e+01
+  %none = torch.constant.none
+  %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
+  return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
+}
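As the CHECK lines suggest, this uniform lowering samples the distribution at conversion time and folds the result into a tosa.const (hence the exact dense<...> literal the test expects), so a converted module returns the same values on every call rather than fresh random numbers; presumably the corresponding e2e tests compare distribution statistics rather than exact elements. For reference, the eager semantics being approximated (illustrative PyTorch sketch, not part of the patch):

import torch

x = torch.empty(3, 4, dtype=torch.float64)
x.uniform_(1.0, 10.0)  # fill in place with samples drawn uniformly from [1.0, 10.0)
assert bool(((x >= 1.0) & (x < 10.0)).all())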