[TOSA] Expand Torch to TOSA legalization coverage (#3827)

- Add/Extend Torch to TOSA legalization for the following ops:
  + Add aten.threshold_backward
  + Fix aten.threshold
  + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile
  + Add support for rank 0 index for aten.index_select
  + Fix aten.index_put.hacked_twin
  + Add aten.uniform
  + Add aten.logical_and
- Update xfail_sets.py with new e2e results
- Add LIT tests to basic.mlir for newly added ops


Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3839/head
Justin Ngo 2024-10-30 16:26:10 -07:00 committed by GitHub
parent a6292f38ca
commit 4dd213b042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 399 additions and 224 deletions

View File

@ -25,6 +25,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
#include <optional>
#include <random>
using namespace mlir;
using namespace mlir::torch;
@ -125,15 +126,14 @@ template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
const int64_t &intValue) {
if (isFloat) {
// Do a round-trip check here instead of numeric limits due to
// compiler warnings around double <-> int conversion.
return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
} else {
assert(isInt);
return (doubleValue >=
static_cast<double>(std::numeric_limits<T>::min())) &&
(doubleValue <= static_cast<double>(std::numeric_limits<T>::max()));
} else if (isInt) {
return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
(intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
}
return true;
return false;
}
// FIXME: This will eventually go into a Tosa*Utils file.
@ -165,13 +165,13 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
dshape, dtype)
.value();
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
auto w = intType.getWidth();
if (w != 1 && w != 32 && w != 64)
auto width = intType.getWidth();
if (width != 1 && width != 8 && width != 32 && width != 64)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unsupported integer type: " << intType;
});
if (w == 1) {
if (width == 1) {
if (!isInValidRange<bool>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
@ -182,7 +182,18 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
tosaTensor = tosa::getConstTensor<bool>(
rewriter, op, SmallVector<bool>(numElem, d), dshape)
.value();
} else if (w == 32) {
} else if (width == 8) {
if (!isInValidRange<int8_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
"of destination type");
}
int8_t d = isFloat ? static_cast<int8_t>(doubleValue)
: static_cast<int8_t>(intValue);
tosaTensor = tosa::getConstTensor<int8_t>(
rewriter, op, SmallVector<int8_t>(numElem, d), dshape)
.value();
} else if (width == 32) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
@ -193,7 +204,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
tosaTensor = tosa::getConstTensor<int32_t>(
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
.value();
} else if (w == 64) {
} else if (width == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure(
op, "Supplied value of scalar constant exceeds limits "
@ -919,13 +930,17 @@ class ConvertAtenMultipleDimsReductionOp
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const override {
SmallVector<int64_t, 4> reduceDims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
int64_t inputRank =
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
SmallVector<int64_t> reduceDims;
// If dim list is none, all dimensions are reduced
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) {
for (int64_t i = 0; i < inputRank; i++)
reduceDims.push_back(i);
}
int64_t N = reduceDims.size();
for (unsigned i = 0; i < N; i++) {
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
if (!isValidDim(reduceDims[i], inputRank))
@ -2895,9 +2910,10 @@ template <>
LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
AtenThresholdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types are currently supported");
@ -2907,12 +2923,9 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
// Integer types with width > 32 are not supported
auto selfIntType = dyn_cast<IntegerType>(selfElemTy);
if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported");
}
auto outType =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
auto outElemTy = outType.getElementType();
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
Value threshold, value;
@ -2922,21 +2935,16 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
op, "Only scalar constant is supported for threshold");
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value,
selfElemTy, constTypeShape)))
outElemTy, constTypeShape)))
return rewriter.notifyMatchFailure(
op, "Only scalar constant is supported for value");
// Threshold only clamps the upper values. tosa::ClampOp has the same
// value for both threshold and clamped value so cannot be used.
auto outType = getTypeConverter()->convertType(op.getType());
auto cmpOp = rewriter.create<tosa::GreaterOp>(
op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
adaptor.getSelf(), threshold);
self, threshold);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
adaptor.getSelf(), value);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp, self, value);
return success();
}
@ -3660,8 +3668,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type.
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
@ -3675,19 +3684,43 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
SmallVector<int64_t> resultShape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants");
"Size must consist of Scalar constants");
int64_t inputRank = selfType.getRank();
int64_t outputRank = resultShape.size();
if (inputRank > outputRank)
return rewriter.notifyMatchFailure(
op, "Input tensor rank cannot be greater than output tensor rank");
// Get the result type
auto resultType = getTypeConverter()->convertType(op.getType());
SmallVector<int64_t> inputShape(
makeShapeTorchCompatible(selfType.getShape()));
// If input rank is smaller than output rank, we reshape the input tensor to
// be the same rank as the output tensor by prepending 1s to the input shape
SmallVector<int64_t> targetInputShape;
for (int64_t i = 0; i < outputRank - inputRank; i++)
targetInputShape.push_back(1);
targetInputShape.append(inputShape);
// Result dimension -1 means not changing the size of that dimension.
// Adjust it by assigning its inputShape.
for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) {
for (auto shape :
llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) {
auto index = shape.index();
if (resultShape[index] == -1)
resultShape[index] = shape.value();
}
for (int64_t i = 0; i < outputRank; i++) {
if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1)
return rewriter.notifyMatchFailure(
op, "Input and result shapes should be equal at each dimension or "
"input shape should be 1");
}
// Check for the identity case, e.g. [a, b, c] -> [a, b, c]. If it holds, we
// can replace the op result with the input operand directly.
if (llvm::equal(inputShape, resultShape)) {
@ -3695,52 +3728,40 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
// since the input and result are of same shape.
op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op);
} else {
// By using reshape and tile ops, support for input rank smaller than result
// rank is allowed. If the rank is smaller, we reshape the input to be the
// same rank as the result, then use tile to expand it. The way it was
// handled before involves adding the input tensor to a const zero tensor of
// output shape to utilize the innate broadcast feature of the TOSA add op.
// That poses the danger of sign bit flips for denormalized values.
// Basically, this approach to broadcast_to legalization allows for more
// flexibility in rank differences and also offers more safety.
Value reshapedInput = self;
if (!llvm::equal(inputShape, targetInputShape))
reshapedInput = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(targetInputShape),
selfElemTy),
self, rewriter.getDenseI64ArrayAttr(targetInputShape));
SmallVector<int64_t> tileOpShape;
for (int64_t i = 0; i < outputRank; i++) {
if (targetInputShape[i] == 1) {
tileOpShape.push_back(resultShape[i]);
} else {
tileOpShape.push_back(1);
}
}
auto result = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, reshapedInput,
rewriter.getDenseI64ArrayAttr(tileOpShape));
rewriter.replaceOp(op, {result.getResult()});
}
return success();
} else if (selfType.hasRank() &&
(selfType.getRank() == (int64_t)resultShape.size() ||
selfType.getRank() == 0)) {
// Right now to support limited cases where input and result shape are not
// equal, we can put a constraint that either the input should be of rank
// 0 or the rank of input tensor and result should be equal. And then we
// can check for broadcasting compatibility for the latter case. For
// broadcasting compatibility, either the shape of input and result should
// be equal at each dimension or one of them should be 1.
if (selfType.getRank() != 0) {
for (unsigned i = 0; i < inputShape.size(); i++) {
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
resultShape[i] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: either the shape of input and result should "
"be equal at each dimenion or one of them should be 1.");
}
}
}
// If the above condition hold true then we can directly create a const
// zero tensor of shape same as the result shape.
SmallVector<int64_t> zeroTensorShape{resultShape};
// create the 0 constant tensor
int64_t totalNumElements = 1;
for (auto dimSize : zeroTensorShape) {
totalNumElements = dimSize * totalNumElements;
}
// There is some danger here. For edge cases in floating point, x + 0 != x.
// The cases are denormalized values, which may get flushed, and -0 + 0 =
// +0. (sign bit flips). These are probably acceptable in the short term,
// but we should put a comment acknowledging the danger, as there isn't an
// op that avoids the denorm flushing.
Value zeroTensor =
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
// Use add broadcast
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
zeroTensor);
return success();
}
return rewriter.notifyMatchFailure(
op,
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
}
template <>
@ -3843,6 +3864,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
auto index = adaptor.getIndex();
auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape();
if (!indexType)
return rewriter.notifyMatchFailure(
@ -3851,9 +3873,13 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
auto inputShape = inputType.getShape();
int inputRank = inputType.getRank();
if (indexType.getRank() == 0)
return rewriter.notifyMatchFailure(
op, "Rank 0 index tensor is currently not supported");
if (indexType.getRank() == 0) {
indexShape = makeShapeTorchCompatible({1});
index = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(indexShape, indexType.getElementType()), index,
rewriter.getDenseI64ArrayAttr(indexShape));
}
// Dynamic shape check
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
@ -3865,9 +3891,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexType.getShape(),
rewriter.getIntegerType(32)),
index);
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}
// Get positive dim
@ -3896,7 +3920,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
SmallVector<int64_t> indicesInputRankShape;
for (int64_t i = 0; i < inputRank; i++) {
if (i == dim) {
indicesInputRankShape.push_back(indexType.getShape()[0]);
indicesInputRankShape.push_back(indexShape[0]);
} else {
indicesInputRankShape.push_back(1);
}
@ -3952,49 +3976,41 @@ template <>
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// a = torch.tensor([[0, 1, 2, 3]])
// a[..., 1:] = torch.tensor([4, 5, 6])
// = a[..., 1:4] = torch.tensor([4, 5, 6])
// = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5,
// 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
// (torch.tensor([0, 0, 0]), torch.tensor([1, 2,
// 3])), # indicies torch.tensor([4, 5, 6])) #
// value
// = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
// (None, torch.tensor([1, 2, 3]),),# indicies
// torch.tensor([4, 5, 6])) # value
// Not a tensor type.
auto input = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
auto selfType = dyn_cast<TensorType>(input.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
auto fillValues = adaptor.getValues();
auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
auto valuesType = dyn_cast<TensorType>(fillValues.getType());
if (!valuesType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
// Deal with torch.prim.ListConstruct of non const value to get the index
// Index_put-like ops are now decomposed to aten.index_put.hacked_twin with
// stricter semantics, i.e., no None index in indices argument.
auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
return op.emitError("Tensor list is not from list construct");
auto indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
auto outType = getTypeConverter()->convertType(op.getType());
// convert list of indices with none into indices tensor without none
// indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3])
// ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]]
if (indexTensors.size() <= 1) {
bool accumulate{false};
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
return rewriter.notifyMatchFailure(
op, "Only support indexput with multiple index.");
}
op, "Accumulate is not a constant bool value");
// No support for accumulate mode yet
if (accumulate)
return rewriter.notifyMatchFailure(
op, "Accumulate mode is not currently supported");
SmallVector<Value> indicesTfConcatTensors;
SmallVector<int64_t> indexesRank;
SmallVector<SmallVector<int64_t>> indexesShape;
@ -4002,28 +4018,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
auto index = indexTensors[i];
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index other than i==0, like (index0, None)
// (None, index1)
if (i == 0 && isa<Torch::NoneType>(indexTorch.getType())) {
// convert None to [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
if (isa<Torch::NoneType>(indexNextTorch.getType())) {
return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now.");
}
auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
auto indexNextShape = indexNextType.getShape();
int64_t size = 1;
for (auto s : indexNextShape)
size *= s;
SmallVector<int32_t> values(size, i);
index =
tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
.value();
}
auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape();
@ -4031,20 +4025,19 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
indexesRank.push_back(indexType.getRank());
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
if (indexType.getElementType() != rewriter.getIntegerType(32))
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index);
}
// Expand last dim of index to tf indices [3] -> [3,1]
// convert [0,0,0] to [[0],[0],[0]]
SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) {
for (auto shape : indexShape)
indiceShapeOneDim.push_back(shape);
}
indiceShapeOneDim.push_back(1);
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
@ -4061,7 +4054,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
op, "Only support indices with same shape");
}
}
@ -4075,19 +4068,16 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim);
if (!indicesTf) {
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf scatterNd algorithm with tf style indices as input, algorithm
// mostly take from convertGatherNdOp.
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch index to TensorFlow indices failed");
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.getResult(), fillValues);
if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert ScatterNdOp fail for index tensor.");
}
if (!result)
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
rewriter.replaceOp(op, {result.value()});
return success();
@ -6632,6 +6622,140 @@ LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
return success();
}
// Legalization for aten.uniform
// Since TOSA hasn't got a built-in random generator yet, the samples are drawn
// while the pattern runs, using std::uniform_real_distribution with the
// std::default_random_engine from the C++ <random> library, and are then
// materialized as a constant tensor that replaces the op.
//
// Requirements for a successful match:
//   - `self` is a tensor with a static shape (every element is generated
//     eagerly, so the element count has to be known),
//   - no custom generator is supplied,
//   - `from` and `to` are constant int or float scalars with from <= to.
template <>
LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
    AtenUniformOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a tensor type
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  // The number of samples is the product of all dimensions; a dynamic
  // dimension would make that product meaningless.
  if (!selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Only tensor types with static shape are supported");

  auto selfShape = selfType.getShape();

  auto generator = adaptor.getGenerator();
  if (!isa<Torch::NoneType>(generator.getType()))
    return rewriter.notifyMatchFailure(op,
                                       "Custom generators are not supported");

  // Accept each bound as either a constant float or a constant int, so mixed
  // int/float bounds are handled as well.
  auto matchConstantScalar = [](Value scalar, double &out) {
    if (matchPattern(scalar, m_TorchConstantFloat(&out)))
      return true;
    int64_t intValue{0};
    if (matchPattern(scalar, m_TorchConstantInt(&intValue))) {
      out = static_cast<double>(intValue);
      return true;
    }
    return false;
  };

  double from{0.0}, to{1.0};
  if (!matchConstantScalar(op.getFrom(), from) ||
      !matchConstantScalar(op.getTo(), to))
    return rewriter.notifyMatchFailure(
        op, "From and To values are not constant values");

  // std::uniform_real_distribution requires a <= b; anything else is a
  // precondition violation.
  if (from > to)
    return rewriter.notifyMatchFailure(
        op, "From value must not be greater than To value");

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(op,
                                       "Result type is not a tensor type");

  int64_t numElem = 1;
  for (int64_t dimSize : selfShape)
    numElem *= dimSize;

  // NOTE: a default-constructed engine always starts from the same seed, so
  // every legalized aten.uniform yields the same sample sequence.
  std::default_random_engine gen;
  std::uniform_real_distribution<float> uniformDist(static_cast<float>(from),
                                                    static_cast<float>(to));

  SmallVector<float> uniformVec;
  uniformVec.reserve(numElem);
  for (int64_t i = 0; i < numElem; i++)
    uniformVec.push_back(uniformDist(gen));

  auto constTensor = tosa::getConstTensor<float>(
      rewriter, op, uniformVec, selfShape, selfType.getElementType());
  if (!constTensor)
    return rewriter.notifyMatchFailure(
        op, "Failed to create constant tensor of uniform samples");

  Value result = tosa::promoteType(rewriter, constTensor.value(), resultType);

  rewriter.replaceOp(op, {result});

  return success();
}
// Legalization for aten.threshold_backward
// result = self <= threshold ? 0 : grad
//
// The threshold must be a constant scalar. It is materialized in the element
// type of `self` with an all-ones shape of the same rank as `self`, and the
// comparison relies on TOSA broadcasting, the same way the aten.threshold
// legalization builds its constants. Building it with the full shape of `self`
// would multiply the dimension sizes together, which is not meaningful when
// `self` has dynamic dimensions.
template <>
LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
    AtenThresholdBackwardOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a tensor type
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  // The shape and rank queries below need a ranked tensor.
  if (!selfType.hasRank())
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor types are supported");

  auto grad = adaptor.getGradOutput();

  // Not a tensor type
  auto gradType = dyn_cast<TensorType>(grad.getType());
  if (!gradType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(op,
                                       "Result type is not a tensor type");

  auto selfElemTy = selfType.getElementType();
  auto selfShape = selfType.getShape();
  auto resultElemTy = resultType.getElementType();

  // Broadcastable constant shape: one element, same rank as self.
  SmallVector<int64_t> thresholdShape(selfType.getRank(), 1);

  Value threshold;
  if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold,
                                     selfElemTy, thresholdShape)))
    return rewriter.notifyMatchFailure(op,
                                       "Threshold must be a constant scalar");

  // Zero constant in the result element type. A null Value marks an element
  // type for which no zero constant can be built here.
  Value zero =
      TypeSwitch<Type, Value>(resultElemTy)
          .Case<mlir::FloatType>([&](auto) -> Value {
            return tosa::getConstTensor<float>(rewriter, op, 0, {},
                                               resultElemTy)
                .value();
          })
          .Case<mlir::IntegerType>([&](auto intType) -> Value {
            switch (intType.getWidth()) {
            case 1:
              return tosa::getConstTensor<bool>(rewriter, op, 0, {}).value();
            case 8:
              return tosa::getConstTensor<int8_t>(rewriter, op, 0, {}).value();
            case 32:
              return tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
            case 64:
              return tosa::getConstTensor<int64_t>(rewriter, op, 0, {}).value();
            default:
              return Value();
            }
          })
          .Default([](Type) -> Value { return Value(); });

  if (!zero)
    return rewriter.notifyMatchFailure(
        op, "Only floating-point or 1/8/32/64-bit integer result types are "
            "supported");

  // Check: input <= threshold (expressed as threshold >= input)
  auto cond = rewriter.create<tosa::GreaterEqualOp>(
      op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()),
      threshold, self);

  // Only grad flows into the result, so only grad needs to be promoted to the
  // result type; self is consumed solely by the comparison above.
  grad = tosa::promoteType(rewriter, grad, resultType);

  auto result = rewriter.create<tosa::SelectOp>(op->getLoc(), resultType,
                                                cond.getResult(), zero, grad);

  rewriter.replaceOp(op, {result.getResult()});

  return success();
}
} // namespace
// -----------------------------------------------------------------------------
@ -6705,6 +6829,7 @@ public:
INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
tosa::LogicalLeftShiftOp)
INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
@ -6947,6 +7072,8 @@ public:
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -1707,9 +1707,17 @@ TOSA_CRASHING_SET = {
"ScatterSrcStaticModule_basic",
# Runtime op verification: Out of bounds access
"ReduceAllDimEmpty_basic",
# SmallVector unable to grow for ThresholdBackward1d
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
}
FX_IMPORTER_TOSA_CRASHING_SET = {
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
"GridSamplerBasic4_basic",
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic",
"HBC_basic",
@ -1727,6 +1735,25 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"CosineSimilarityStaticBroadcastModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseRreluTrainStaticModule_basic",
"IndexSelectRank0IdxModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimModule_basic",
"SliceCopy_Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
@ -2296,8 +2323,6 @@ TOSA_PASS_SET = {
"TensorIntModule_basic",
"TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"TensorsConcatStaticModule_basic",
"TestF16Return_basic",
"TestMultipleTensorReturn_basic",
"Threshold1dFloatModule_basic",
@ -2363,7 +2388,6 @@ TOSA_PASS_SET = {
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
@ -2468,7 +2492,6 @@ MAKE_FX_TOSA_PASS_SET = (
"SplitWithSizesListUnpackModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
@ -2487,7 +2510,6 @@ MAKE_FX_TOSA_PASS_SET = (
"ElementwiseLogSigmoidModule_basic",
# failed to legalize operation 'torch.aten.rrelu_with_noise'
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
# incompatible return type failure for tosa.concat.
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
@ -3329,6 +3351,14 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MaxPool3dStaticModule_basic",
"ViewDtypeStaticModule_basic",
"Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4",
@ -3474,7 +3504,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"CeilFloatModule_basic",
"CollapseAllDimensionsModule_basic",
"CollapseFullDynamicModule_basic",
@ -3509,7 +3538,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic",
"CumsumModule_basic",
"CumsumStaticModule_basic",
@ -3524,8 +3552,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"DeterminantModule_F32",
"DivFloatModule_basic",
"DivIntModule_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAcosIntModule_basic",
"ElementwiseAcosModule_basic",
"ElementwiseAcoshIntModule_basic",
@ -3545,11 +3571,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampTensorFloatModule_basic",
@ -3590,12 +3612,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"ExpandModule_basic",
"ExponentialModule_basic",
"FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"FullModuleInt2D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
@ -3606,42 +3625,25 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"GtFloatIntModule_basic",
"GtIntModule_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DImplicitModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexSelectRank0IdxModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
@ -3656,8 +3658,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulStaticBroadcast_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
@ -3689,17 +3690,16 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
"MaxPool3dWithIndicesStaticModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
"MseLossMeanReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"MlGroupNormManualModule_basic",
"MlGroupNormModule_basic",
"MlLayerNormManualModule_basic",
"MlLayerNormModule_basic",
"MulFloatModule_basic",
"MulIntModule_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
@ -3741,14 +3741,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RandModule_basic",
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
@ -3760,9 +3755,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceL1NormComplexModule_basic",
"ReduceL1NormWithDTypeModule_basic",
"ReduceL2NormComplexModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimComplexModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
@ -3843,18 +3836,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
"ThresholdBackward2dFloatModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward2dMixedModule_basic",
"ThresholdBackward3dFloatModule_basic",
"ThresholdBackward3dIntModule_basic",
"ThresholdBackward3dMixedModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic",
@ -3863,10 +3845,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic",
"UniformModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic",
@ -3875,9 +3853,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic",
"VarMeanBiasedModule_basic",
"VarMeanCorrectionNoneModule_basic",
"VarMeanUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
@ -3894,6 +3869,15 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
"RreluWithNoiseBackwardTrainStaticModule_basic",
"RreluWithNoiseForwardBackwardModule_basic",
"Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
@ -3937,12 +3921,10 @@ ONNX_TOSA_XFAIL_SET = {
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"EinsumStaticModule_basic",
"ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic",
"ElementwiseRad2DegIntModule_basic",
"ElementwiseRad2DegModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
@ -4106,7 +4088,6 @@ ONNX_TOSA_XFAIL_SET = {
"BoolIntConstantModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"BoolTensorHandleSignless_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"BucketizeTensorFloatModule_basic",
@ -4123,10 +4104,6 @@ ONNX_TOSA_XFAIL_SET = {
"CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic",
"ConstantBoolParameterModule_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"Conv1dModule_basic",
@ -4220,9 +4197,7 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
@ -4254,7 +4229,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorFloatModule_basic",
@ -4291,7 +4265,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorFloatModule_basic",
"ElementwiseMulTensorIntModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
@ -4579,8 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"PadModule_basic",
"PadWithNoneValModule_basic",
"PermuteNegativeIndexModule_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
@ -4688,7 +4659,6 @@ ONNX_TOSA_XFAIL_SET = {
"ReflectionPad2dModule_Right",
"ReflectionPad2dModule_Top",
"ReflectionPad2dModule_basic",
"RepeatModule_basic",
"ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0",

View File

@ -2162,4 +2162,82 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
return %1 : !torch.vtensor<[4,2],si64>
}
}
// -----
// Verifies the TOSA lowering of aten.threshold_backward on si64 tensors: the
// expected output compares the threshold against `self` with greater_equal and
// uses the resulting mask to select zero, otherwise the incoming gradient.
// CHECK-LABEL:   func.func @torch.aten.threshold_backward$basic(
// CHECK-SAME:        %[[GRAD_ARG:.*]]: !torch.vtensor<[4],si64>,
// CHECK-SAME:        %[[SELF_ARG:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
// CHECK:           %[[SELF:.*]] = torch_c.to_builtin_tensor %[[SELF_ARG]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
// CHECK:           %[[GRAD:.*]] = torch_c.to_builtin_tensor %[[GRAD_ARG]] : !torch.vtensor<[4],si64> -> tensor<4xi64>
// CHECK:           %[[THRESHOLD_SCALAR:.*]] = torch.constant.int 1
// CHECK:           %[[THRESHOLD:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[MASK:.*]] = tosa.greater_equal %[[THRESHOLD]], %[[SELF]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1>
// CHECK:           %[[SELECTED:.*]] = tosa.select %[[MASK]], %[[ZERO]], %[[GRAD]] : (tensor<4xi1>, tensor<i64>, tensor<4xi64>) -> tensor<4xi64>
// CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[SELECTED]] : tensor<4xi64> -> !torch.vtensor<[4],si64>
// CHECK:           return %[[RESULT]] : !torch.vtensor<[4],si64>
// CHECK:         }
func.func @torch.aten.threshold_backward$basic(%grad: !torch.vtensor<[4],si64>, %self: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> {
  %threshold = torch.constant.int 1
  %result = torch.aten.threshold_backward %grad, %self, %threshold : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %result : !torch.vtensor<[4],si64>
}
// -----
// Verifies the TOSA lowering of aten.threshold on an si64 tensor with a float
// threshold (0.5) and an integer replacement value (2). In the expected output
// the threshold appears as the integer constant 0; elements greater than it
// are kept and all others are replaced by 2.
// CHECK-LABEL:   func.func @torch.aten.threshold$basic(
// CHECK-SAME:        %[[SELF_ARG:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
// CHECK:           %[[SELF:.*]] = torch_c.to_builtin_tensor %[[SELF_ARG]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64>
// CHECK:           %[[THRESHOLD_SCALAR:.*]] = torch.constant.float 5.000000e-01
// CHECK:           %[[VALUE_SCALAR:.*]] = torch.constant.int 2
// CHECK:           %[[THRESHOLD:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK:           %[[VALUE:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK:           %[[MASK:.*]] = tosa.greater %[[SELF]], %[[THRESHOLD]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1>
// CHECK:           %[[SELECTED:.*]] = tosa.select %[[MASK]], %[[SELF]], %[[VALUE]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64>
// CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[SELECTED]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64>
// CHECK:           return %[[RESULT]] : !torch.vtensor<[4,5],si64>
// CHECK:         }
func.func @torch.aten.threshold$basic(%self: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> {
  %threshold = torch.constant.float 5.000000e-01
  %value = torch.constant.int 2
  %result = torch.aten.threshold %self, %threshold, %value : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> !torch.vtensor<[4,5],si64>
  return %result : !torch.vtensor<[4,5],si64>
}
// -----
// Verifies that aten.logical_and on two i1 tensors of identical static shape
// lowers to a single tosa.logical_and, with operand order preserved.
// CHECK-LABEL:   func.func @torch.aten.logical_and$basic(
// CHECK-SAME:        %[[LHS_ARG:.*]]: !torch.vtensor<[4,5],i1>,
// CHECK-SAME:        %[[RHS_ARG:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
// CHECK:           %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_ARG]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
// CHECK:           %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_ARG]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1>
// CHECK:           %[[AND:.*]] = tosa.logical_and %[[LHS]], %[[RHS]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1>
// CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[AND]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1>
// CHECK:           return %[[RESULT]] : !torch.vtensor<[4,5],i1>
// CHECK:         }
func.func @torch.aten.logical_and$basic(%lhs: !torch.vtensor<[4,5],i1>, %rhs: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> {
  %result = torch.aten.logical_and %lhs, %rhs : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1>
  return %result : !torch.vtensor<[4,5],i1>
}
// -----
// Verifies the TOSA lowering of aten.uniform on a 3x4 f64 tensor: the expected
// output is a 3x4 f32 tosa.const that is cast to f64 and returned for both
// results.
//
// The element values of the constant are deliberately matched with a regex
// instead of exact literals. They are random samples drawn at compile time,
// and the sequence a default-constructed C++ random engine/distribution yields
// is implementation-defined, so pinned literals only match on one standard
// library. NOTE(review): presumably the lowering samples via <random>; confirm
// against the conversion pattern if exact values ever need to be pinned again.
// CHECK-LABEL:   func.func @torch.aten.uniform$basic(
// CHECK-SAME:        %[[SELF_ARG:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
// CHECK:           %[[FROM:.*]] = torch.constant.float 1.000000e+00
// CHECK:           %[[TO:.*]] = torch.constant.float 1.000000e+01
// CHECK:           %[[GENERATOR:.*]] = torch.constant.none
// CHECK:           %[[SAMPLES:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
// CHECK:           %[[SAMPLES_F64:.*]] = tosa.cast %[[SAMPLES]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
// CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[SAMPLES_F64]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
// CHECK:           return %[[RESULT]], %[[RESULT]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
// CHECK:         }
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
  %float1.000000e00 = torch.constant.float 1.000000e+00
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %none = torch.constant.none
  %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
  return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
}