[TOSA] Expand Torch to TOSA legalization coverage (#3827)

- Add/Extend Torch to TOSA legalization for the following ops:
  + Add aten.threshold_backward
  + Fix aten.threshold
  + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile
  + Add support for rank 0 index for aten.index_select
  + Fix aten.index_put.hacked_twin
  + Add aten.uniform
  + Add aten.logical_and
- Update xfail_sets.py with new e2e results
- Add LIT tests to basic.mlir for newly added ops


Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3839/head
Justin Ngo 2024-10-30 16:26:10 -07:00 committed by GitHub
parent a6292f38ca
commit 4dd213b042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 399 additions and 224 deletions

View File

@ -25,6 +25,7 @@
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
#include <numeric> #include <numeric>
#include <optional> #include <optional>
#include <random>
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -125,15 +126,14 @@ template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
const int64_t &intValue) { const int64_t &intValue) {
if (isFloat) { if (isFloat) {
// Do a round-trip check here instead of numeric limits due to return (doubleValue >=
// compiler warnings around double <-> int conversion. static_cast<double>(std::numeric_limits<T>::min())) &&
return (doubleValue == static_cast<double>(static_cast<T>(doubleValue))); (doubleValue <= static_cast<double>(std::numeric_limits<T>::max()));
} else { } else if (isInt) {
assert(isInt);
return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) && return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
(intValue <= static_cast<int64_t>(std::numeric_limits<T>::max())); (intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
} }
return true; return false;
} }
// FIXME: This will eventually go into a Tosa*Utils file. // FIXME: This will eventually go into a Tosa*Utils file.
@ -165,13 +165,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
dshape, dtype) dshape, dtype)
.value(); .value();
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) { } else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth(); auto width = intType.getWidth();
if (w != 1 && w != 32 && w != 64) if (width != 1 && width != 8 && width != 32 && width != 64)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unsupported integer type: " << intType; diag << "Unsupported integer type: " << intType;
}); });
if (w == 1) { if (width == 1) {
if (!isInValidRange<bool>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<bool>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits " op, "Supplied value of scalar constant exceeds limits "
@ -182,7 +182,18 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
tosaTensor = tosa::getConstTensor<bool>( tosaTensor = tosa::getConstTensor<bool>(
rewriter, op, SmallVector<bool>(numElem, d), dshape) rewriter, op, SmallVector<bool>(numElem, d), dshape)
.value(); .value();
} else if (w == 32) { } else if (width == 8) {
if (!isInValidRange<int8_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
"of destination type");
}
int8_t d = isFloat ? static_cast<int8_t>(doubleValue)
: static_cast<int8_t>(intValue);
tosaTensor = tosa::getConstTensor<int8_t>(
rewriter, op, SmallVector<int8_t>(numElem, d), dshape)
.value();
} else if (width == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits " op, "Supplied value of scalar constant exceeds limits "
@ -193,7 +204,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
tosaTensor = tosa::getConstTensor<int32_t>( tosaTensor = tosa::getConstTensor<int32_t>(
rewriter, op, SmallVector<int32_t>(numElem, d), dshape) rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
.value(); .value();
} else if (w == 64) { } else if (width == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits " op, "Supplied value of scalar constant exceeds limits "
@ -919,13 +930,17 @@ class ConvertAtenMultipleDimsReductionOp
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr, ElementsAttr &reduceDimsAttr,
bool &keepDims) const override { bool &keepDims) const override {
SmallVector<int64_t, 4> reduceDims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
int64_t inputRank = int64_t inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank(); cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
SmallVector<int64_t> reduceDims;
// If dim list is none, all dimensions are reduced
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) {
for (int64_t i = 0; i < inputRank; i++)
reduceDims.push_back(i);
}
int64_t N = reduceDims.size();
for (unsigned i = 0; i < N; i++) { for (unsigned i = 0; i < N; i++) {
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank); reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
if (!isValidDim(reduceDims[i], inputRank)) if (!isValidDim(reduceDims[i], inputRank))
@ -2895,9 +2910,10 @@ template <>
LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
AtenThresholdOp op, OpAdaptor adaptor, AtenThresholdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type. // Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType()); auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType) if (!selfType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types are currently supported"); op, "Only tensor types are currently supported");
@ -2907,12 +2923,9 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported"); op, "Only floating-point or integer datatype legalization supported");
// Integer types with width > 32 are not supported auto outType =
auto selfIntType = dyn_cast<IntegerType>(selfElemTy); dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
if (selfIntType && selfIntType.getWidth() > 32) { auto outElemTy = outType.getElementType();
return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported");
}
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1); SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
Value threshold, value; Value threshold, value;
@ -2922,21 +2935,16 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
op, "Only scalar constant is supported for threshold"); op, "Only scalar constant is supported for threshold");
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value, if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value,
selfElemTy, constTypeShape))) outElemTy, constTypeShape)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only scalar constant is supported for value"); op, "Only scalar constant is supported for value");
// Threshold only clamps the upper values. tosa::ClampOp has the same
// value for both threshold and clamped value so cannot be used.
auto outType = getTypeConverter()->convertType(op.getType());
auto cmpOp = rewriter.create<tosa::GreaterOp>( auto cmpOp = rewriter.create<tosa::GreaterOp>(
op.getLoc(), op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
adaptor.getSelf(), threshold); self, threshold);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp, rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp, self, value);
adaptor.getSelf(), value);
return success(); return success();
} }
@ -3660,8 +3668,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
AtenBroadcastToOp op, OpAdaptor adaptor, AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type. // Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType()); auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType || !selfType.hasStaticShape()) if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported"); op, "Only tensor types with static shape are supported");
@ -3675,19 +3684,43 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
SmallVector<int64_t> resultShape; SmallVector<int64_t> resultShape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants"); "Size must consist of Scalar constants");
int64_t inputRank = selfType.getRank();
int64_t outputRank = resultShape.size();
if (inputRank > outputRank)
return rewriter.notifyMatchFailure(
op, "Input tensor rank cannot be greater than output tensor rank");
// Get the result type // Get the result type
auto resultType = getTypeConverter()->convertType(op.getType()); auto resultType = getTypeConverter()->convertType(op.getType());
SmallVector<int64_t> inputShape( SmallVector<int64_t> inputShape(
makeShapeTorchCompatible(selfType.getShape())); makeShapeTorchCompatible(selfType.getShape()));
// If input rank is smaller than output rank, we reshape the input tensor to
// be the same rank as the output tensor by prepending 1s to the input shape
SmallVector<int64_t> targetInputShape;
for (int64_t i = 0; i < outputRank - inputRank; i++)
targetInputShape.push_back(1);
targetInputShape.append(inputShape);
// Result dimension -1 means not changing the size of that dimension. // Result dimension -1 means not changing the size of that dimension.
// Adjust it by assigning its inputShape. // Adjust it by assigning its inputShape.
for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) { for (auto shape :
llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) {
auto index = shape.index(); auto index = shape.index();
if (resultShape[index] == -1) if (resultShape[index] == -1)
resultShape[index] = shape.value(); resultShape[index] = shape.value();
} }
for (int64_t i = 0; i < outputRank; i++) {
if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1)
return rewriter.notifyMatchFailure(
op, "Input and result shapes should be equal at each dimension or "
"input shape should be 1");
}
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand directly. // true then we can replace the op result with the input operand directly.
if (llvm::equal(inputShape, resultShape)) { if (llvm::equal(inputShape, resultShape)) {
@ -3695,52 +3728,40 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
// since the input and result are of same shape. // since the input and result are of same shape.
op.replaceAllUsesWith(op.getSelf()); op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op); rewriter.eraseOp(op);
} else {
// By using reshape and tile ops, support for input rank smaller than result
// rank is allowed. If the rank is smaller, we reshape the input to be the
// same rank as the result, then use tile to expand it. The way it was
// handled before involves adding the input tensor to a const zero tensor of
// output shape to utilize the innate broadcast feature of the TOSA add op.
// That poses the danger of sign bit flips for denormalized values.
// Basically, this approach to broadcast_to legalization allows for more
// flexibility in rank differences and also offers more safety.
Value reshapedInput = self;
if (!llvm::equal(inputShape, targetInputShape))
reshapedInput = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(targetInputShape),
selfElemTy),
self, rewriter.getDenseI64ArrayAttr(targetInputShape));
SmallVector<int64_t> tileOpShape;
for (int64_t i = 0; i < outputRank; i++) {
if (targetInputShape[i] == 1) {
tileOpShape.push_back(resultShape[i]);
} else {
tileOpShape.push_back(1);
}
}
auto result = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, reshapedInput,
rewriter.getDenseI64ArrayAttr(tileOpShape));
rewriter.replaceOp(op, {result.getResult()});
}
return success(); return success();
} else if (selfType.hasRank() &&
(selfType.getRank() == (int64_t)resultShape.size() ||
selfType.getRank() == 0)) {
// Right now to support limited cases where input and result shape are not
// equal, we can put a constraint that either the input should be of rank
// 0 or the rank of input tensor and result should be equal. And then we
// can check for broadcasting compatibility for the latter case. For
// broadcasting compatibility, either the shape of input and result should
// be equal at each dimenion or one of them should be 1.
if (selfType.getRank() != 0) {
for (unsigned i = 0; i < inputShape.size(); i++) {
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
resultShape[i] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: either the shape of input and result should "
"be equal at each dimenion or one of them should be 1.");
}
}
}
// If the above condition hold true then we can directly create a const
// zero tensor of shape same as the result shape.
SmallVector<int64_t> zeroTensorShape{resultShape};
// create the 0 constant tensor
int64_t totalNumElements = 1;
for (auto dimSize : zeroTensorShape) {
totalNumElements = dimSize * totalNumElements;
}
// There is some danger here. For edge cases in floating point, x + 0 != x.
// The cases are denormalized values, which may get flushed, and -0 + 0 =
// +0. (sign bit flips). These are probably acceptable in the short term,
// but we should put a comment acknowledging the danger, as there isn't an
// op that avoids the denorm flushing.
Value zeroTensor =
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
// Use add broadcast
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
zeroTensor);
return success();
}
return rewriter.notifyMatchFailure(
op,
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
} }
template <> template <>
@ -3843,6 +3864,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
auto index = adaptor.getIndex(); auto index = adaptor.getIndex();
auto indexType = dyn_cast<RankedTensorType>(index.getType()); auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape();
if (!indexType) if (!indexType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -3851,9 +3873,13 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
auto inputShape = inputType.getShape(); auto inputShape = inputType.getShape();
int inputRank = inputType.getRank(); int inputRank = inputType.getRank();
if (indexType.getRank() == 0) if (indexType.getRank() == 0) {
return rewriter.notifyMatchFailure( indexShape = makeShapeTorchCompatible({1});
op, "Rank 0 index tensor is currently not supported"); index = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(indexShape, indexType.getElementType()), index,
rewriter.getDenseI64ArrayAttr(indexShape));
}
// Dynamic shape check // Dynamic shape check
if (!inputType.hasStaticShape() || !indexType.hasStaticShape()) if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
@ -3865,9 +3891,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
if (indexType.getElementType() != rewriter.getIntegerType(32)) { if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>( index = rewriter.create<tosa::CastOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(indexType.getShape(), RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
rewriter.getIntegerType(32)),
index);
} }
// Get positive dim // Get positive dim
@ -3896,7 +3920,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
SmallVector<int64_t> indicesInputRankShape; SmallVector<int64_t> indicesInputRankShape;
for (int64_t i = 0; i < inputRank; i++) { for (int64_t i = 0; i < inputRank; i++) {
if (i == dim) { if (i == dim) {
indicesInputRankShape.push_back(indexType.getShape()[0]); indicesInputRankShape.push_back(indexShape[0]);
} else { } else {
indicesInputRankShape.push_back(1); indicesInputRankShape.push_back(1);
} }
@ -3952,49 +3976,41 @@ template <>
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
// a = torch.tensor([[0, 1, 2, 3]])
// a[..., 1:] = torch.tensor([4, 5, 6])
// = a[..., 1:4] = torch.tensor([4, 5, 6])
// = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5,
// 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
// (torch.tensor([0, 0, 0]), torch.tensor([1, 2,
// 3])), # indicies torch.tensor([4, 5, 6])) #
// value
// = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
// (None, torch.tensor([1, 2, 3]),),# indicies
// torch.tensor([4, 5, 6])) # value
// Not a tensor type. // Not a tensor type.
auto input = adaptor.getSelf(); auto input = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType()); auto selfType = dyn_cast<TensorType>(input.getType());
if (!selfType) if (!selfType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported"); op, "Only tensor types input are currently supported");
auto fillValues = adaptor.getValues(); auto fillValues = adaptor.getValues();
auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType()); auto valuesType = dyn_cast<TensorType>(fillValues.getType());
if (!valuesType) if (!valuesType)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported"); op, "Only tensor types input are currently supported");
// Deal with torch.prim.ListConstruct of non const value to get the index // Deal with torch.prim.ListConstruct of non const value to get the index
// Index_put-like ops are now decomposed to aten.index_put.hacked_twin with
// stricter semantics, i.e., no None index in indices argument.
auto tensorList = op.getIndices(); auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType; SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType)) if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError( return op.emitError("Tensor list is not from list construct");
"unimplemented: the tensor list is not from list construct");
auto indexTensors = getTypeConvertedValues( auto indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType); rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
auto outType = getTypeConverter()->convertType(op.getType()); auto outType = getTypeConverter()->convertType(op.getType());
// convert list of indices with none into indices tensor without none bool accumulate{false};
// indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3]) if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
// ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]]
if (indexTensors.size() <= 1) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only support indexput with multiple index."); op, "Accumulate is not a constant bool value");
}
// No support for accumulate mode yet
if (accumulate)
return rewriter.notifyMatchFailure(
op, "Accumulate mode is not currently supported");
SmallVector<Value> indicesTfConcatTensors; SmallVector<Value> indicesTfConcatTensors;
SmallVector<int64_t> indexesRank; SmallVector<int64_t> indexesRank;
SmallVector<SmallVector<int64_t>> indexesShape; SmallVector<SmallVector<int64_t>> indexesShape;
@ -4002,28 +4018,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
// concat index tensor into to indices tensor for concat // concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) { for (size_t i = 0; i < indexTensors.size(); i++) {
auto index = indexTensors[i]; auto index = indexTensors[i];
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index other than i==0, like (index0, None)
// (None, index1)
if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
// convert None to [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
if (isa<Torch::NoneType>(indexNextTorch.getType())) {
return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now.");
}
auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
auto indexNextShape = indexNextType.getShape();
int64_t size = 1;
for (auto s : indexNextShape)
size *= s;
SmallVector<int32_t> values(size, i);
index =
tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
.value();
}
auto indexType = dyn_cast<RankedTensorType>(index.getType()); auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape(); auto indexShape = indexType.getShape();
@ -4031,20 +4025,19 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
indexesRank.push_back(indexType.getRank()); indexesRank.push_back(indexType.getRank());
// index i64 to i32 for tosa compatible // index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) { if (indexType.getElementType() != rewriter.getIntegerType(32))
index = rewriter.create<tosa::CastOp>( index = rewriter.create<tosa::CastOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index); index);
}
// Expand last dim of index to tf indices [3] -> [3,1] // Expand last dim of index to tf indices [3] -> [3,1]
// convert [0,0,0] to [[0],[0],[0]] // convert [0,0,0] to [[0],[0],[0]]
SmallVector<int64_t> indiceShapeOneDim; SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) { for (auto shape : indexShape)
indiceShapeOneDim.push_back(shape); indiceShapeOneDim.push_back(shape);
}
indiceShapeOneDim.push_back(1); indiceShapeOneDim.push_back(1);
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>( auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), rewriter, op->getLoc(),
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
@ -4061,7 +4054,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
for (auto indexShapeOneDim : indexesShape) { for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) { if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape"); op, "Only support indices with same shape");
} }
} }
@ -4075,19 +4068,16 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim); indicesTfConcatTensors, lastDim);
if (!indicesTf) { if (!indicesTf)
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(
"Convert TorchIndex To TfIndices fail."); op, "Convert PyTorch index to TensorFlow indices failed");
}
// do the tf scatterNd algorithm with tf style indices as input, algorithm
// mostly take from convertGatherNdOp.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.getResult(), fillValues); indicesTf.getResult(), fillValues);
if (!result) { if (!result)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
op, "Convert ScatterNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()}); rewriter.replaceOp(op, {result.value()});
return success(); return success();
@ -6632,6 +6622,140 @@ LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
return success(); return success();
} }
// Legalization for aten.uniform
// Since TOSA hasn't got a built-in random generator yet, we will use
// std::uniform_real_distribution with the std::default_random_engine from C++
// <random> library
template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
AtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfShape = selfType.getShape();
auto generator = adaptor.getGenerator();
if (!isa<Torch::NoneType>(generator.getType()))
return rewriter.notifyMatchFailure(op,
"Custom generators are not supported");
double fromDouble{0.0}, toDouble{1.0};
auto isFloat =
matchPattern(op.getFrom(), m_TorchConstantFloat(&fromDouble)) &&
matchPattern(op.getTo(), m_TorchConstantFloat(&toDouble));
int64_t fromInt{0}, toInt{1};
auto isInt = matchPattern(op.getFrom(), m_TorchConstantInt(&fromInt)) &&
matchPattern(op.getTo(), m_TorchConstantInt(&toInt));
if (!isFloat && !isInt)
return rewriter.notifyMatchFailure(
op, "From and To values are not constant values");
int64_t numElem = 1;
for (int64_t i = 0; i < selfType.getRank(); i++)
numElem *= selfShape[i];
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
std::default_random_engine gen;
auto from = isFloat ? fromDouble : fromInt;
auto to = isFloat ? toDouble : toInt;
std::uniform_real_distribution<float> uniformDist(from, to);
SmallVector<float> uniformVec;
for (int64_t i = 0; i < numElem; i++)
uniformVec.push_back(uniformDist(gen));
auto result = tosa::getConstTensor<float>(rewriter, op, uniformVec, selfShape,
selfType.getElementType())
.value();
result = tosa::promoteType(rewriter, result, resultType);
rewriter.replaceOp(op, {result});
return success();
}
// Legalization for aten.threshold_backward
// result = self <= threshold ? 0 : grad
template <>
LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
AtenThresholdBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfElemTy = selfType.getElementType();
auto selfShape = selfType.getShape();
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();
Value threshold;
if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold,
selfElemTy, selfShape)))
return rewriter.notifyMatchFailure(op,
"Threshold must be a constant scalar");
auto grad = adaptor.getGradOutput();
// Not a tensor type
auto gradType = dyn_cast<TensorType>(grad.getType());
if (!gradType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
Value zero =
TypeSwitch<Type, Value>(resultElemTy)
.Case<mlir::FloatType>([&](auto) {
return tosa::getConstTensor<float>(rewriter, op, 0, {},
resultElemTy)
.value();
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return tosa::getConstTensor<bool>(rewriter, op, 0, {}).value();
case 8:
return tosa::getConstTensor<int8_t>(rewriter, op, 0, {}).value();
case 32:
return tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
case 64:
return tosa::getConstTensor<int64_t>(rewriter, op, 0, {}).value();
}
llvm_unreachable("Invalid integer width");
});
// Check: input <= threshold
auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
threshold, self);
self = tosa::promoteType(rewriter, self, resultType);
grad = tosa::promoteType(rewriter, grad, resultType);
auto result = rewriter.create<tosa::SelectOp>(op->getLoc(), resultType,
cond.getResult(), zero, grad);
rewriter.replaceOp(op, {result.getResult()});
return success();
}
} // namespace } // namespace
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -6705,6 +6829,7 @@ public:
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
tosa::LogicalLeftShiftOp) tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
@ -6947,6 +7072,8 @@ public:
INSERT_ATENOP_PATTERN(AtenScatterSrcOp); INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp); INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -1707,9 +1707,17 @@ TOSA_CRASHING_SET = {
"ScatterSrcStaticModule_basic", "ScatterSrcStaticModule_basic",
# Runtime op verification: Out of bounds access # Runtime op verification: Out of bounds access
"ReduceAllDimEmpty_basic", "ReduceAllDimEmpty_basic",
# SmallVector unable to grow for ThresholdBackward1d
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
} }
FX_IMPORTER_TOSA_CRASHING_SET = { FX_IMPORTER_TOSA_CRASHING_SET = {
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
"GridSamplerBasic4_basic",
"ScatterSrcModule_basic", "ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic", "ScatterSrcStaticModule_basic",
"HBC_basic", "HBC_basic",
@ -1727,6 +1735,25 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"CosineSimilarityStaticBroadcastModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseRreluTrainStaticModule_basic",
"IndexSelectRank0IdxModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimModule_basic",
"SliceCopy_Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"EmptyModule_contiguous", "EmptyModule_contiguous",
"EmptyModule_defaultDtype", "EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory", "EmptyModule_falsePinMemory",
@ -2296,8 +2323,6 @@ TOSA_PASS_SET = {
"TensorIntModule_basic", "TensorIntModule_basic",
"TensorLiteralModule_basic", "TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic", "TensorOpaqueLiteralModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"TensorsConcatStaticModule_basic",
"TestF16Return_basic", "TestF16Return_basic",
"TestMultipleTensorReturn_basic", "TestMultipleTensorReturn_basic",
"Threshold1dFloatModule_basic", "Threshold1dFloatModule_basic",
@ -2363,7 +2388,6 @@ TOSA_PASS_SET = {
"LinspaceModule_basic", "LinspaceModule_basic",
"LinspaceOneSizeModule_basic", "LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic", "LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic", "RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic", "RenormModuleFloat32_basic",
"IndexTensorStaticContiguousWithNoneModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic",
@ -2468,7 +2492,6 @@ MAKE_FX_TOSA_PASS_SET = (
"SplitWithSizesListUnpackModule_basic", "SplitWithSizesListUnpackModule_basic",
# Dynamic shape, has extra unsupported broadcast ops # Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d", "Matmul_3d",
"MatmulStaticBroadcast_basic",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin' # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic",
@ -2487,7 +2510,6 @@ MAKE_FX_TOSA_PASS_SET = (
"ElementwiseLogSigmoidModule_basic", "ElementwiseLogSigmoidModule_basic",
# failed to legalize operation 'torch.aten.rrelu_with_noise' # failed to legalize operation 'torch.aten.rrelu_with_noise'
"ElementwiseRreluEvalModule_basic", "ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
# incompatible return type failure for tosa.concat. # incompatible return type failure for tosa.concat.
"HstackBasicComplexModule_basic", "HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic", "HstackBasicFloatModule_basic",
@ -3329,6 +3351,14 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
} }
FX_IMPORTER_TOSA_XFAIL_SET = { FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MaxPool3dStaticModule_basic",
"ViewDtypeStaticModule_basic", "ViewDtypeStaticModule_basic",
"Unfold_Module_Dynamic_basic", "Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4", "Unfold_Module_Rank_4",
@ -3474,7 +3504,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"BoolIntFalseModule_basic", "BoolIntFalseModule_basic",
"BoolIntTrueModule_basic", "BoolIntTrueModule_basic",
"BroadcastDynamicDimModule_basic", "BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"CeilFloatModule_basic", "CeilFloatModule_basic",
"CollapseAllDimensionsModule_basic", "CollapseAllDimensionsModule_basic",
"CollapseFullDynamicModule_basic", "CollapseFullDynamicModule_basic",
@ -3509,7 +3538,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic", "ConvolutionModule2DTranspose_basic",
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"CumsumModule_basic", "CumsumModule_basic",
"CumsumStaticModule_basic", "CumsumStaticModule_basic",
@ -3524,8 +3552,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"DeterminantModule_F32", "DeterminantModule_F32",
"DivFloatModule_basic", "DivFloatModule_basic",
"DivIntModule_basic", "DivIntModule_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAcosIntModule_basic", "ElementwiseAcosIntModule_basic",
"ElementwiseAcosModule_basic", "ElementwiseAcosModule_basic",
"ElementwiseAcoshIntModule_basic", "ElementwiseAcoshIntModule_basic",
@ -3545,11 +3571,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic", "ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic", "ElementwiseAtanhModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseClampMinTensorFloatModule_basic", "ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic", "ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic", "ElementwiseClampTensorFloatModule_basic",
@ -3590,12 +3612,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseUnaryIntModule_basic", "ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic", "EqIntModule_basic",
"ExpandModule_basic",
"ExponentialModule_basic",
"FloatImplicitModule_basic", "FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic", "FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic", "FullLikeModuleInt3D_basic",
"FullModuleInt2D_basic",
"GeFloatIntModule_basic", "GeFloatIntModule_basic",
"GeFloatModule_basic", "GeFloatModule_basic",
"GeIntModule_basic", "GeIntModule_basic",
@ -3606,42 +3625,25 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic", "IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic", "IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic", "IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic", "IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic", "IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic", "IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic", "IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic", "IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic", "IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic", "IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DImplicitModule_basic", "IndexPutImpl2DImplicitModule_basic",
"IndexPutImpl2DIndexModule_basic", "IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic", "IndexPutImplIndexWithNoneModule_basic",
"IndexSelectRank0IdxModule_basic",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateStaticModule_scales_bilinear_align_corners",
@ -3656,8 +3658,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinspaceDtypeModule_basic", "LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic", "LinspaceEmptyModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic", "MaskedScatterStaticBasic_basic",
"MatmulStaticBroadcast_basic",
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic", "MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic", "MaxPool2dCeilModeTrueModule_basic",
@ -3689,17 +3690,16 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
"MaxPool3dWithIndicesStaticModule_basic", "MaxPool3dWithIndicesStaticModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic", "MlGroupNormManualModule_basic",
"MseLossMeanReductionModule_basic", "MlGroupNormModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic", "MlLayerNormManualModule_basic",
"MlLayerNormModule_basic",
"MulFloatModule_basic", "MulFloatModule_basic",
"MulIntModule_basic", "MulIntModule_basic",
"NativeBatchNorm1DModule_basic", "NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic", "NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic", "NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic", "NativeBatchNormNoneWeightModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormBackwardModule_basic", "NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",
@ -3741,14 +3741,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"QuantizedReluInt8_basic", "QuantizedReluInt8_basic",
"QuantizedReluUint8_basic", "QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic", "QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic", "RandIntLowModule_basic",
"RandIntModule_basic", "RandIntModule_basic",
"RandIntPinMemoryModule_basic", "RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RandModule_basic",
"RandnDtypeDeviceModule_basic", "RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic", "RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic", "RandnGeneratorModule_basic",
@ -3760,9 +3755,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceL1NormComplexModule_basic", "ReduceL1NormComplexModule_basic",
"ReduceL1NormWithDTypeModule_basic", "ReduceL1NormWithDTypeModule_basic",
"ReduceL2NormComplexModule_basic", "ReduceL2NormComplexModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",
"ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDimIntListEmptyDimModule_basic",
@ -3843,18 +3836,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TensorsConcatPromoteDTypeModule_basic", "TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
"ThresholdBackward2dFloatModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward2dMixedModule_basic", "ThresholdBackward2dMixedModule_basic",
"ThresholdBackward3dFloatModule_basic",
"ThresholdBackward3dIntModule_basic",
"ThresholdBackward3dMixedModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic", "ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic", "ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic", "TorchPrimLoopForLikeModule_basic",
@ -3863,10 +3845,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TraceUnsignedIntModule_empty", "TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic", "TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic", "TypeConversionI1ToI32Module_basic",
"UniformModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic", "UpSampleNearest2dBackward_basic",
@ -3875,9 +3853,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic", "UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic", "UpSampleNearest2d_basic",
"VarMeanBiasedModule_basic",
"VarMeanCorrectionNoneModule_basic",
"VarMeanUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic", "VisionTransformerModule_basic",
@ -3894,6 +3869,15 @@ ONNX_TOSA_CRASHING_SET = {
} }
ONNX_TOSA_XFAIL_SET = { ONNX_TOSA_XFAIL_SET = {
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
"RreluWithNoiseForwardBackwardModule_basic",
"Unfold_Module_Dynamic_basic", "Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4", "Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Rank_Zero_Size_Zero_basic",
@ -3937,12 +3921,10 @@ ONNX_TOSA_XFAIL_SET = {
"Conv_Transpose2dStaticModule_basic", "Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic", "Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic", "Conv_Transpose3dStaticModule_basic",
"EinsumStaticModule_basic",
"ElementwiseFmaxModule_basic", "ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic", "ElementwiseFminModule_basic",
"ElementwiseGeluApproximateTanhModule_basic", "ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic",
"ElementwiseRad2DegIntModule_basic", "ElementwiseRad2DegIntModule_basic",
"ElementwiseRad2DegModule_basic", "ElementwiseRad2DegModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
@ -4106,7 +4088,6 @@ ONNX_TOSA_XFAIL_SET = {
"BoolIntConstantModule_basic", "BoolIntConstantModule_basic",
"BoolIntFalseModule_basic", "BoolIntFalseModule_basic",
"BoolIntTrueModule_basic", "BoolIntTrueModule_basic",
"BoolTensorHandleSignless_basic",
"BroadcastDynamicDimModule_basic", "BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic", "BroadcastToModule_basic",
"BucketizeTensorFloatModule_basic", "BucketizeTensorFloatModule_basic",
@ -4123,10 +4104,6 @@ ONNX_TOSA_XFAIL_SET = {
"CollapseRank1DynamicModule_basic", "CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic", "CollapseStaticModule_basic",
"ConstantBoolParameterModule_basic", "ConstantBoolParameterModule_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"ContainsIntList_False", "ContainsIntList_False",
"ContainsIntList_True", "ContainsIntList_True",
"Conv1dModule_basic", "Conv1dModule_basic",
@ -4220,9 +4197,7 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAtenFloorDivideTensorPositiveModule_basic", "ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic", "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
@ -4254,7 +4229,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseCoshModule_basic", "ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic", "ElementwiseDequantizePerTensorModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic", "ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorFloatModule_basic", "ElementwiseDivTensorFloatModule_basic",
@ -4291,7 +4265,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorFloatModule_basic", "ElementwiseMulTensorFloatModule_basic",
"ElementwiseMulTensorIntModule_basic", "ElementwiseMulTensorIntModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseOrTensorModule_basic", "ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic",
@ -4579,8 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
"OnesLikeModule_falsePinMemory", "OnesLikeModule_falsePinMemory",
"OnesLikeModule_float", "OnesLikeModule_float",
"OnesLikeModule_int", "OnesLikeModule_int",
"PadModule_basic",
"PadWithNoneValModule_basic",
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
"PixelShuffleModuleFullDynamic_basic", "PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyDynamic_basic",
@ -4688,7 +4659,6 @@ ONNX_TOSA_XFAIL_SET = {
"ReflectionPad2dModule_Right", "ReflectionPad2dModule_Right",
"ReflectionPad2dModule_Top", "ReflectionPad2dModule_Top",
"ReflectionPad2dModule_basic", "ReflectionPad2dModule_basic",
"RepeatModule_basic",
"ReplicationPad2dModule_basic", "ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0", "ReplicationPad2dModule_left0",

View File

@ -2162,4 +2162,82 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor> %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64> %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
return %1 : !torch.vtensor<[4,2],si64> return %1 : !torch.vtensor<[4,2],si64>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1>
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<4xi1>, tensor<i64>, tensor<4xi64>) -> tensor<4xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi64> -> !torch.vtensor<[4],si64>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],si64>
// CHECK: }
func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
%int1 = torch.constant.int 1
%0 = torch.aten.threshold_backward %arg0, %arg1, %int1 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.threshold$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_1]], %[[VAL_4]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1>
// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_6]], %[[VAL_1]], %[[VAL_5]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5],si64>
// CHECK: }
func.func @torch.aten.threshold$basic(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int2 = torch.constant.int 2
%0 = torch.aten.threshold %arg0, %float5.000000e-01, %int2 : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> !torch.vtensor<[4,5],si64>
return %0 : !torch.vtensor<[4,5],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_3]], %[[VAL_2]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,5],i1>
// CHECK: }
func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
%0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
return %0 : !torch.vtensor<[4,5],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.uniform$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
// CHECK: }
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
%float1.000000e00 = torch.constant.float 1.000000e+00
%float1.000000e01 = torch.constant.float 1.000000e+01
%none = torch.constant.none
%0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
}