mirror of https://github.com/llvm/torch-mlir
[TOSA] Expand Torch to TOSA legalization coverage (#3827)
- Add/Extend Torch to TOSA legalization for the following ops: + Add aten.threshold_backward + Fix aten.threshold + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile + Add support for rank 0 index for aten.index_select + Fix aten.index_put.hacked_twin + Add aten.uniform + Add aten.logical_and - Update xfail_sets.py with new e2e results - Add LIT tests to basic.mlir for newly added ops Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3839/head
parent
a6292f38ca
commit
4dd213b042
|
@ -25,6 +25,7 @@
|
|||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -125,15 +126,14 @@ template <typename T>
|
|||
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
|
||||
const int64_t &intValue) {
|
||||
if (isFloat) {
|
||||
// Do a round-trip check here instead of numeric limits due to
|
||||
// compiler warnings around double <-> int conversion.
|
||||
return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
|
||||
} else {
|
||||
assert(isInt);
|
||||
return (doubleValue >=
|
||||
static_cast<double>(std::numeric_limits<T>::min())) &&
|
||||
(doubleValue <= static_cast<double>(std::numeric_limits<T>::max()));
|
||||
} else if (isInt) {
|
||||
return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
|
||||
(intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// FIXME: This will eventually go into a Tosa*Utils file.
|
||||
|
@ -165,13 +165,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
dshape, dtype)
|
||||
.value();
|
||||
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
auto w = intType.getWidth();
|
||||
if (w != 1 && w != 32 && w != 64)
|
||||
auto width = intType.getWidth();
|
||||
if (width != 1 && width != 8 && width != 32 && width != 64)
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "Unsupported integer type: " << intType;
|
||||
});
|
||||
|
||||
if (w == 1) {
|
||||
if (width == 1) {
|
||||
if (!isInValidRange<bool>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value of scalar constant exceeds limits "
|
||||
|
@ -182,7 +182,18 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
tosaTensor = tosa::getConstTensor<bool>(
|
||||
rewriter, op, SmallVector<bool>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 32) {
|
||||
} else if (width == 8) {
|
||||
if (!isInValidRange<int8_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value of scalar constant exceeds limits "
|
||||
"of destination type");
|
||||
}
|
||||
int8_t d = isFloat ? static_cast<int8_t>(doubleValue)
|
||||
: static_cast<int8_t>(intValue);
|
||||
tosaTensor = tosa::getConstTensor<int8_t>(
|
||||
rewriter, op, SmallVector<int8_t>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (width == 32) {
|
||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value of scalar constant exceeds limits "
|
||||
|
@ -193,7 +204,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 64) {
|
||||
} else if (width == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value of scalar constant exceeds limits "
|
||||
|
@ -919,13 +930,17 @@ class ConvertAtenMultipleDimsReductionOp
|
|||
ConversionPatternRewriter &rewriter,
|
||||
ElementsAttr &reduceDimsAttr,
|
||||
bool &keepDims) const override {
|
||||
SmallVector<int64_t, 4> reduceDims;
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
int64_t N = reduceDims.size();
|
||||
int64_t inputRank =
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
|
||||
SmallVector<int64_t> reduceDims;
|
||||
// If dim list is none, all dimensions are reduced
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) {
|
||||
for (int64_t i = 0; i < inputRank; i++)
|
||||
reduceDims.push_back(i);
|
||||
}
|
||||
|
||||
int64_t N = reduceDims.size();
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
|
||||
if (!isValidDim(reduceDims[i], inputRank))
|
||||
|
@ -2895,9 +2910,10 @@ template <>
|
|||
LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
||||
AtenThresholdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
|
@ -2907,12 +2923,9 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
|
||||
// Integer types with width > 32 are not supported
|
||||
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
|
||||
if (selfIntType && selfIntType.getWidth() > 32) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Integer types with width greater than 32 are not supported");
|
||||
}
|
||||
auto outType =
|
||||
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto outElemTy = outType.getElementType();
|
||||
|
||||
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
|
||||
Value threshold, value;
|
||||
|
@ -2922,21 +2935,16 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
|||
op, "Only scalar constant is supported for threshold");
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value,
|
||||
selfElemTy, constTypeShape)))
|
||||
outElemTy, constTypeShape)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only scalar constant is supported for value");
|
||||
|
||||
// Threshold only clamps the upper values. tosa::ClampOp has the same
|
||||
// value for both threshold and clamped value so cannot be used.
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
auto cmpOp = rewriter.create<tosa::GreaterOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
|
||||
adaptor.getSelf(), threshold);
|
||||
self, threshold);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
|
||||
adaptor.getSelf(), value);
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp, self, value);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -3660,8 +3668,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto self = adaptor.getSelf();
|
||||
// Not a tensor type.
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types with static shape are supported");
|
||||
|
@ -3675,19 +3684,43 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
SmallVector<int64_t> resultShape;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"size must consist of Scalar constants");
|
||||
"Size must consist of Scalar constants");
|
||||
|
||||
int64_t inputRank = selfType.getRank();
|
||||
int64_t outputRank = resultShape.size();
|
||||
if (inputRank > outputRank)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input tensor rank cannot be greater than output tensor rank");
|
||||
|
||||
// Get the result type
|
||||
auto resultType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
SmallVector<int64_t> inputShape(
|
||||
makeShapeTorchCompatible(selfType.getShape()));
|
||||
|
||||
// If input rank is smaller than output rank, we reshape the input tensor to
|
||||
// be the same rank as the output tensor by prepending 1s to the input shape
|
||||
SmallVector<int64_t> targetInputShape;
|
||||
for (int64_t i = 0; i < outputRank - inputRank; i++)
|
||||
targetInputShape.push_back(1);
|
||||
targetInputShape.append(inputShape);
|
||||
|
||||
// Result dimension -1 means not changing the size of that dimension.
|
||||
// Adjust it by assigning its inputShape.
|
||||
for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) {
|
||||
for (auto shape :
|
||||
llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) {
|
||||
auto index = shape.index();
|
||||
if (resultShape[index] == -1)
|
||||
resultShape[index] = shape.value();
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < outputRank; i++) {
|
||||
if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input and result shapes should be equal at each dimension or "
|
||||
"input shape should be 1");
|
||||
}
|
||||
|
||||
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
|
||||
// true then we can replace the op result with the input operand directly.
|
||||
if (llvm::equal(inputShape, resultShape)) {
|
||||
|
@ -3695,52 +3728,40 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
// since the input and result are of same shape.
|
||||
op.replaceAllUsesWith(op.getSelf());
|
||||
rewriter.eraseOp(op);
|
||||
} else {
|
||||
// By using reshape and tile ops, support for input rank smaller than result
|
||||
// rank is allowed. If the rank is smaller, we reshape the input to be the
|
||||
// same rank as the result, then use tile to expand it. The way it was
|
||||
// handled before involves adding the input tensor to a const zero tensor of
|
||||
// output shape to utilize the innate broadcast feature of the TOSA add op.
|
||||
// That poses the danger of sign bit flips for denormalized values.
|
||||
// Basically, this approach to broadcast_to legalization allows for more
|
||||
// flexibility in rank differences and also offers more safety.
|
||||
Value reshapedInput = self;
|
||||
if (!llvm::equal(inputShape, targetInputShape))
|
||||
reshapedInput = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(targetInputShape),
|
||||
selfElemTy),
|
||||
self, rewriter.getDenseI64ArrayAttr(targetInputShape));
|
||||
|
||||
SmallVector<int64_t> tileOpShape;
|
||||
for (int64_t i = 0; i < outputRank; i++) {
|
||||
if (targetInputShape[i] == 1) {
|
||||
tileOpShape.push_back(resultShape[i]);
|
||||
} else {
|
||||
tileOpShape.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
auto result = rewriter.create<tosa::TileOp>(
|
||||
op->getLoc(), resultType, reshapedInput,
|
||||
rewriter.getDenseI64ArrayAttr(tileOpShape));
|
||||
|
||||
rewriter.replaceOp(op, {result.getResult()});
|
||||
}
|
||||
|
||||
return success();
|
||||
} else if (selfType.hasRank() &&
|
||||
(selfType.getRank() == (int64_t)resultShape.size() ||
|
||||
selfType.getRank() == 0)) {
|
||||
// Right now to support limited cases where input and result shape are not
|
||||
// equal, we can put a constraint that either the input should be of rank
|
||||
// 0 or the rank of input tensor and result should be equal. And then we
|
||||
// can check for broadcasting compatibility for the latter case. For
|
||||
// broadcasting compatibility, either the shape of input and result should
|
||||
// be equal at each dimenion or one of them should be 1.
|
||||
if (selfType.getRank() != 0) {
|
||||
for (unsigned i = 0; i < inputShape.size(); i++) {
|
||||
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
|
||||
resultShape[i] != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: either the shape of input and result should "
|
||||
"be equal at each dimenion or one of them should be 1.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the above condition hold true then we can directly create a const
|
||||
// zero tensor of shape same as the result shape.
|
||||
SmallVector<int64_t> zeroTensorShape{resultShape};
|
||||
|
||||
// create the 0 constant tensor
|
||||
int64_t totalNumElements = 1;
|
||||
for (auto dimSize : zeroTensorShape) {
|
||||
totalNumElements = dimSize * totalNumElements;
|
||||
}
|
||||
// There is some danger here. For edge cases in floating point, x + 0 != x.
|
||||
// The cases are denormalized values, which may get flushed, and -0 + 0 =
|
||||
// +0. (sign bit flips). These are probably acceptable in the short term,
|
||||
// but we should put a comment acknowledging the danger, as there isn't an
|
||||
// op that avoids the denorm flushing.
|
||||
Value zeroTensor =
|
||||
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
|
||||
|
||||
// Use add broadcast
|
||||
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
|
||||
zeroTensor);
|
||||
return success();
|
||||
}
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -3843,6 +3864,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
|
||||
auto index = adaptor.getIndex();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
|
||||
if (!indexType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -3851,9 +3873,13 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
auto inputShape = inputType.getShape();
|
||||
int inputRank = inputType.getRank();
|
||||
|
||||
if (indexType.getRank() == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Rank 0 index tensor is currently not supported");
|
||||
if (indexType.getRank() == 0) {
|
||||
indexShape = makeShapeTorchCompatible({1});
|
||||
index = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, indexType.getElementType()), index,
|
||||
rewriter.getDenseI64ArrayAttr(indexShape));
|
||||
}
|
||||
|
||||
// Dynamic shape check
|
||||
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
|
||||
|
@ -3865,9 +3891,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexType.getShape(),
|
||||
rewriter.getIntegerType(32)),
|
||||
index);
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
|
||||
}
|
||||
|
||||
// Get positive dim
|
||||
|
@ -3896,7 +3920,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
SmallVector<int64_t> indicesInputRankShape;
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (i == dim) {
|
||||
indicesInputRankShape.push_back(indexType.getShape()[0]);
|
||||
indicesInputRankShape.push_back(indexShape[0]);
|
||||
} else {
|
||||
indicesInputRankShape.push_back(1);
|
||||
}
|
||||
|
@ -3952,49 +3976,41 @@ template <>
|
|||
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// a = torch.tensor([[0, 1, 2, 3]])
|
||||
// a[..., 1:] = torch.tensor([4, 5, 6])
|
||||
// = a[..., 1:4] = torch.tensor([4, 5, 6])
|
||||
// = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5,
|
||||
// 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
|
||||
// (torch.tensor([0, 0, 0]), torch.tensor([1, 2,
|
||||
// 3])), # indicies torch.tensor([4, 5, 6])) #
|
||||
// value
|
||||
// = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
|
||||
// (None, torch.tensor([1, 2, 3]),),# indicies
|
||||
// torch.tensor([4, 5, 6])) # value
|
||||
|
||||
// Not a tensor type.
|
||||
auto input = adaptor.getSelf();
|
||||
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||
auto selfType = dyn_cast<TensorType>(input.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
||||
auto fillValues = adaptor.getValues();
|
||||
auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
|
||||
auto valuesType = dyn_cast<TensorType>(fillValues.getType());
|
||||
if (!valuesType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
||||
// Deal with torch.prim.ListConstruct of non const value to get the index
|
||||
// Index_put-like ops are now decomposed to aten.index_put.hacked_twin with
|
||||
// stricter semantics, i.e., no None index in indices argument.
|
||||
auto tensorList = op.getIndices();
|
||||
SmallVector<Value> tensorsTorchType;
|
||||
if (!getListConstructElements(tensorList, tensorsTorchType))
|
||||
return op.emitError(
|
||||
"unimplemented: the tensor list is not from list construct");
|
||||
return op.emitError("Tensor list is not from list construct");
|
||||
auto indexTensors = getTypeConvertedValues(
|
||||
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// convert list of indices with none into indices tensor without none
|
||||
// indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3])
|
||||
// ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]]
|
||||
if (indexTensors.size() <= 1) {
|
||||
bool accumulate{false};
|
||||
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support indexput with multiple index.");
|
||||
}
|
||||
op, "Accumulate is not a constant bool value");
|
||||
|
||||
// No support for accumulate mode yet
|
||||
if (accumulate)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Accumulate mode is not currently supported");
|
||||
|
||||
SmallVector<Value> indicesTfConcatTensors;
|
||||
SmallVector<int64_t> indexesRank;
|
||||
SmallVector<SmallVector<int64_t>> indexesShape;
|
||||
|
@ -4002,28 +4018,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto index = indexTensors[i];
|
||||
auto indexTorch = tensorsTorchType[i];
|
||||
// TODO add support for none index other than i==0, like (index0, None)
|
||||
// (None, index1)
|
||||
if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
|
||||
// convert None to [0,0,0]
|
||||
auto indexNext = indexTensors[i + 1];
|
||||
auto indexNextTorch = tensorsTorchType[i + 1];
|
||||
if (isa<Torch::NoneType>(indexNextTorch.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Multiple None index is not support for now.");
|
||||
}
|
||||
auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
|
||||
auto indexNextShape = indexNextType.getShape();
|
||||
|
||||
int64_t size = 1;
|
||||
for (auto s : indexNextShape)
|
||||
size *= s;
|
||||
SmallVector<int32_t> values(size, i);
|
||||
index =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
|
||||
.value();
|
||||
}
|
||||
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
|
@ -4031,20 +4025,19 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
indexesRank.push_back(indexType.getRank());
|
||||
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32))
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
|
||||
// Expand last dim of index to tf indices [3] -> [3,1]
|
||||
// convert [0,0,0] to [[0],[0],[0]]
|
||||
SmallVector<int64_t> indiceShapeOneDim;
|
||||
for (auto shape : indexShape) {
|
||||
for (auto shape : indexShape)
|
||||
indiceShapeOneDim.push_back(shape);
|
||||
}
|
||||
indiceShapeOneDim.push_back(1);
|
||||
|
||||
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
|
||||
|
@ -4061,7 +4054,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
for (auto indexShapeOneDim : indexesShape) {
|
||||
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: Only support multi indexes with same shape");
|
||||
op, "Only support indices with same shape");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4075,19 +4068,16 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
|
||||
indicesTfConcatTensors, lastDim);
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Convert TorchIndex To TfIndices fail.");
|
||||
}
|
||||
// do the tf scatterNd algorithm with tf style indices as input, algorithm
|
||||
// mostly take from convertGatherNdOp.
|
||||
if (!indicesTf)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert PyTorch index to TensorFlow indices failed");
|
||||
|
||||
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult(), fillValues);
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert ScatterNdOp fail for index tensor.");
|
||||
}
|
||||
if (!result)
|
||||
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
|
||||
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
|
@ -6632,6 +6622,140 @@ LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.uniform
|
||||
// Since TOSA hasn't got a built-in random generator yet, we will use
|
||||
// std::uniform_real_distribution with the std::default_random_engine from C++
|
||||
// <random> library
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
||||
AtenUniformOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a tensor type
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
auto selfShape = selfType.getShape();
|
||||
|
||||
auto generator = adaptor.getGenerator();
|
||||
if (!isa<Torch::NoneType>(generator.getType()))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Custom generators are not supported");
|
||||
|
||||
double fromDouble{0.0}, toDouble{1.0};
|
||||
auto isFloat =
|
||||
matchPattern(op.getFrom(), m_TorchConstantFloat(&fromDouble)) &&
|
||||
matchPattern(op.getTo(), m_TorchConstantFloat(&toDouble));
|
||||
|
||||
int64_t fromInt{0}, toInt{1};
|
||||
auto isInt = matchPattern(op.getFrom(), m_TorchConstantInt(&fromInt)) &&
|
||||
matchPattern(op.getTo(), m_TorchConstantInt(&toInt));
|
||||
|
||||
if (!isFloat && !isInt)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "From and To values are not constant values");
|
||||
|
||||
int64_t numElem = 1;
|
||||
for (int64_t i = 0; i < selfType.getRank(); i++)
|
||||
numElem *= selfShape[i];
|
||||
|
||||
auto resultType =
|
||||
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
std::default_random_engine gen;
|
||||
|
||||
auto from = isFloat ? fromDouble : fromInt;
|
||||
auto to = isFloat ? toDouble : toInt;
|
||||
|
||||
std::uniform_real_distribution<float> uniformDist(from, to);
|
||||
SmallVector<float> uniformVec;
|
||||
|
||||
for (int64_t i = 0; i < numElem; i++)
|
||||
uniformVec.push_back(uniformDist(gen));
|
||||
|
||||
auto result = tosa::getConstTensor<float>(rewriter, op, uniformVec, selfShape,
|
||||
selfType.getElementType())
|
||||
.value();
|
||||
|
||||
result = tosa::promoteType(rewriter, result, resultType);
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.threshold_backward
|
||||
// result = self <= threshold ? 0 : grad
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
|
||||
AtenThresholdBackwardOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a tensor type
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
|
||||
auto selfShape = selfType.getShape();
|
||||
|
||||
auto resultType =
|
||||
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||
auto resultElemTy = resultType.getElementType();
|
||||
|
||||
Value threshold;
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold,
|
||||
selfElemTy, selfShape)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Threshold must be a constant scalar");
|
||||
|
||||
auto grad = adaptor.getGradOutput();
|
||||
|
||||
// Not a tensor type
|
||||
auto gradType = dyn_cast<TensorType>(grad.getType());
|
||||
if (!gradType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
Value zero =
|
||||
TypeSwitch<Type, Value>(resultElemTy)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return tosa::getConstTensor<float>(rewriter, op, 0, {},
|
||||
resultElemTy)
|
||||
.value();
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 1:
|
||||
return tosa::getConstTensor<bool>(rewriter, op, 0, {}).value();
|
||||
case 8:
|
||||
return tosa::getConstTensor<int8_t>(rewriter, op, 0, {}).value();
|
||||
case 32:
|
||||
return tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
|
||||
case 64:
|
||||
return tosa::getConstTensor<int64_t>(rewriter, op, 0, {}).value();
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
|
||||
// Check: input <= threshold
|
||||
auto cond = rewriter.create<tosa::GreaterEqualOp>(
|
||||
op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
|
||||
threshold, self);
|
||||
|
||||
self = tosa::promoteType(rewriter, self, resultType);
|
||||
grad = tosa::promoteType(rewriter, grad, resultType);
|
||||
|
||||
auto result = rewriter.create<tosa::SelectOp>(op->getLoc(), resultType,
|
||||
cond.getResult(), zero, grad);
|
||||
|
||||
rewriter.replaceOp(op, {result.getResult()});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -6705,6 +6829,7 @@ public:
|
|||
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
|
||||
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
|
||||
tosa::LogicalLeftShiftOp)
|
||||
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
|
||||
|
@ -6947,6 +7072,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -1707,9 +1707,17 @@ TOSA_CRASHING_SET = {
|
|||
"ScatterSrcStaticModule_basic",
|
||||
# Runtime op verification: Out of bounds access
|
||||
"ReduceAllDimEmpty_basic",
|
||||
# SmallVector unable to grow for ThresholdBackward1d
|
||||
"ThresholdBackward1dFloatModule_basic",
|
||||
"ThresholdBackward1dIntModule_basic",
|
||||
"ThresholdBackward1dMixedModule_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||
"GridSamplerBasic1_basic",
|
||||
"GridSamplerBasic2_basic",
|
||||
"GridSamplerBasic3_basic",
|
||||
"GridSamplerBasic4_basic",
|
||||
"ScatterSrcModule_basic",
|
||||
"ScatterSrcStaticModule_basic",
|
||||
"HBC_basic",
|
||||
|
@ -1727,6 +1735,25 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
"DropoutTrainStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"RandIntDtypeModule_basic",
|
||||
"RandIntLowDtypeModule_basic",
|
||||
"RandModule_basic",
|
||||
"ReduceL3NormAllDimsModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"SliceCopy_Module_basic",
|
||||
"Threshold1dIntModule_basic",
|
||||
"Threshold2dIntModule_basic",
|
||||
"Threshold3dIntModule_basic",
|
||||
"EmptyModule_contiguous",
|
||||
"EmptyModule_defaultDtype",
|
||||
"EmptyModule_falsePinMemory",
|
||||
|
@ -2296,8 +2323,6 @@ TOSA_PASS_SET = {
|
|||
"TensorIntModule_basic",
|
||||
"TensorLiteralModule_basic",
|
||||
"TensorOpaqueLiteralModule_basic",
|
||||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"TensorsConcatStaticModule_basic",
|
||||
"TestF16Return_basic",
|
||||
"TestMultipleTensorReturn_basic",
|
||||
"Threshold1dFloatModule_basic",
|
||||
|
@ -2363,7 +2388,6 @@ TOSA_PASS_SET = {
|
|||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
|
@ -2468,7 +2492,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"SplitWithSizesListUnpackModule_basic",
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
|
@ -2487,7 +2510,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ElementwiseLogSigmoidModule_basic",
|
||||
# failed to legalize operation 'torch.aten.rrelu_with_noise'
|
||||
"ElementwiseRreluEvalModule_basic",
|
||||
"ElementwiseRreluEvalStaticModule_basic",
|
||||
# incompatible return type failure for tosa.concat.
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
|
@ -3329,6 +3351,14 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||
"MaxPool3dLargeDatadModule_basic",
|
||||
"MaxPool3dModuleRandomSimple_basic",
|
||||
"MaxPool3dModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
"Unfold_Module_Rank_4",
|
||||
|
@ -3474,7 +3504,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"CollapseAllDimensionsModule_basic",
|
||||
"CollapseFullDynamicModule_basic",
|
||||
|
@ -3509,7 +3538,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ConvolutionModule2DTransposeStrided_basic",
|
||||
"ConvolutionModule2DTranspose_basic",
|
||||
"CopyWithDifferentDTypesModule_basic",
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
"CumsumInputDtypeInt32Module_basic",
|
||||
"CumsumModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
|
@ -3524,8 +3552,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"DeterminantModule_F32",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
"DropoutTrainStaticShapeModule_basic",
|
||||
"ElementwiseAcosIntModule_basic",
|
||||
"ElementwiseAcosModule_basic",
|
||||
"ElementwiseAcoshIntModule_basic",
|
||||
|
@ -3545,11 +3571,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseAtanhIntModule_basic",
|
||||
"ElementwiseAtanhModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
|
@ -3590,12 +3612,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"EqIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"ExponentialModule_basic",
|
||||
"FloatImplicitModule_basic",
|
||||
"FullLikeModuleInt2D_basic",
|
||||
"FullLikeModuleInt3D_basic",
|
||||
"FullModuleInt2D_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
"GeIntModule_basic",
|
||||
|
@ -3606,42 +3625,25 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
"IndexPut1DIntAccumulateModule_basic",
|
||||
"IndexPut1DIntNonAccumulateModule_basic",
|
||||
"IndexPut2DFloatAccumulateModule_basic",
|
||||
"IndexPut2DFloatNonAccumulateModule_basic",
|
||||
"IndexPut2DIntAccumulateModule_basic",
|
||||
"IndexPut2DIntNonAccumulateModule_basic",
|
||||
"IndexPut3DFloatAccumulateModule_basic",
|
||||
"IndexPut3DFloatNonAccumulateModule_basic",
|
||||
"IndexPut3DIntAccumulateModule_basic",
|
||||
"IndexPut3DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
"IndexPutImpl2DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl2DImplicitModule_basic",
|
||||
"IndexPutImpl2DIndexModule_basic",
|
||||
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
|
@ -3656,8 +3658,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"LinspaceDtypeModule_basic",
|
||||
"LinspaceEmptyModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"MatmulBroadcastBatchDim_basic",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
|
@ -3689,17 +3690,16 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
|
||||
"MaxPool3dWithIndicesStaticModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"MeanDimNoneDimModule_basic",
|
||||
"MseLossMeanReductionModule_basic",
|
||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||
"MlGroupNormManualModule_basic",
|
||||
"MlGroupNormModule_basic",
|
||||
"MlLayerNormManualModule_basic",
|
||||
"MlLayerNormModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"MulIntModule_basic",
|
||||
"NativeBatchNorm1DModule_basic",
|
||||
"NativeBatchNorm2DModule_basic",
|
||||
"NativeBatchNorm3DModule_basic",
|
||||
"NativeBatchNormNoneWeightModule_basic",
|
||||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
"NeFloatIntModule_basic",
|
||||
"NeIntModule_basic",
|
||||
|
@ -3741,14 +3741,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
"QuantizedSingleLayer_basic",
|
||||
"RandIntDtypeModule_basic",
|
||||
"RandIntLowDtypeModule_basic",
|
||||
"RandIntLowModule_basic",
|
||||
"RandIntModule_basic",
|
||||
"RandIntPinMemoryModule_basic",
|
||||
"RandLikeDtypeModule_basic",
|
||||
"RandLikeModule_basic",
|
||||
"RandModule_basic",
|
||||
"RandnDtypeDeviceModule_basic",
|
||||
"RandnGeneratorF64Module_basic",
|
||||
"RandnGeneratorModule_basic",
|
||||
|
@ -3760,9 +3755,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ReduceL1NormComplexModule_basic",
|
||||
"ReduceL1NormWithDTypeModule_basic",
|
||||
"ReduceL2NormComplexModule_basic",
|
||||
"ReduceL3NormAllDimsModule_basic",
|
||||
"ReduceL3NormKeepDimComplexModule_basic",
|
||||
"ReduceL3NormKeepDimModule_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
|
@ -3843,18 +3836,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||
"Threshold1dIntModule_basic",
|
||||
"Threshold2dIntModule_basic",
|
||||
"Threshold3dIntModule_basic",
|
||||
"ThresholdBackward1dFloatModule_basic",
|
||||
"ThresholdBackward1dIntModule_basic",
|
||||
"ThresholdBackward1dMixedModule_basic",
|
||||
"ThresholdBackward2dFloatModule_basic",
|
||||
"ThresholdBackward2dIntModule_basic",
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
"ThresholdBackward3dFloatModule_basic",
|
||||
"ThresholdBackward3dIntModule_basic",
|
||||
"ThresholdBackward3dMixedModule_basic",
|
||||
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
||||
"ToCopyWithDTypeModule_basic",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
|
@ -3863,10 +3845,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"TraceUnsignedIntModule_empty",
|
||||
"TypeConversionI1ToF64Module_basic",
|
||||
"TypeConversionI1ToI32Module_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformNoCorrelationModule_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"UpSampleNearest2dBackwardScalesNone_basic",
|
||||
"UpSampleNearest2dBackward_basic",
|
||||
|
@ -3875,9 +3853,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"UpSampleNearest2dStaticFactor_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
"UpSampleNearest2d_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
"VarMeanCorrectionNoneModule_basic",
|
||||
"VarMeanUnbiasedModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"VisionTransformerModule_basic",
|
||||
|
@ -3894,6 +3869,15 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
"RreluWithNoiseBackwardEvalModule_basic",
|
||||
"RreluWithNoiseBackwardEvalStaticModule_basic",
|
||||
"RreluWithNoiseBackwardTrainModule_basic",
|
||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||
"RreluWithNoiseForwardBackwardModule_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
"Unfold_Module_Rank_4",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
|
@ -3937,12 +3921,10 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||
"ElementwiseNanToNumWithNoneModule_Basic",
|
||||
"ElementwiseRad2DegIntModule_basic",
|
||||
"ElementwiseRad2DegModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
|
@ -4106,7 +4088,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"BoolIntConstantModule_basic",
|
||||
"BoolIntFalseModule_basic",
|
||||
"BoolIntTrueModule_basic",
|
||||
"BoolTensorHandleSignless_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"BucketizeTensorFloatModule_basic",
|
||||
|
@ -4123,10 +4104,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"CollapseRank1DynamicModule_basic",
|
||||
"CollapseStaticModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
"ConstantPadNdModule_basic",
|
||||
"ConstantPadNdPartialStaticModule_basic",
|
||||
"ConstantPadNdStaticModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv1dModule_basic",
|
||||
|
@ -4220,9 +4197,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||
"ElementwiseAtenIsneginfOpModule_basic",
|
||||
"ElementwiseAtenIsposinfOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
|
||||
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
|
||||
|
@ -4254,7 +4229,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorFloatModule_basic",
|
||||
|
@ -4291,7 +4265,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwiseMulTensorFloatModule_basic",
|
||||
"ElementwiseMulTensorIntModule_basic",
|
||||
"ElementwiseNanToNumModule_Basic",
|
||||
"ElementwiseOrTensorModule_basic",
|
||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
|
@ -4579,8 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"OnesLikeModule_falsePinMemory",
|
||||
"OnesLikeModule_float",
|
||||
"OnesLikeModule_int",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PixelShuffleModuleFullDynamic_basic",
|
||||
"PixelShuffleModuleSpatiallyDynamic_basic",
|
||||
|
@ -4688,7 +4659,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ReflectionPad2dModule_Right",
|
||||
"ReflectionPad2dModule_Top",
|
||||
"ReflectionPad2dModule_basic",
|
||||
"RepeatModule_basic",
|
||||
"ReplicationPad2dModule_basic",
|
||||
"ReplicationPad2dModule_bottom0",
|
||||
"ReplicationPad2dModule_left0",
|
||||
|
|
|
@ -2162,4 +2162,82 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
|
|||
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
|
||||
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
|
||||
return %1 : !torch.vtensor<[4,2],si64>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<4xi1>, tensor<i64>, tensor<4xi64>) -> tensor<4xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi64> -> !torch.vtensor<[4],si64>
|
||||
// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],si64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.threshold_backward %arg0, %arg1, %int1 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
|
||||
return %0 : !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.threshold$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_1]], %[[VAL_4]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_6]], %[[VAL_1]], %[[VAL_5]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5],si64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.threshold$basic(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
|
||||
%float5.000000e-01 = torch.constant.float 5.000000e-01
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.threshold %arg0, %float5.000000e-01, %int2 : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> !torch.vtensor<[4,5],si64>
|
||||
return %0 : !torch.vtensor<[4,5],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_3]], %[[VAL_2]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,5],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
|
||||
%0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
|
||||
return %0 : !torch.vtensor<[4,5],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.uniform$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
|
||||
// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
||||
%float1.000000e00 = torch.constant.float 1.000000e+00
|
||||
%float1.000000e01 = torch.constant.float 1.000000e+01
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
|
||||
return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue