build: update llvm tag to de3f0f7f (#1789)

Credit to @vivekkhandelwal1 for finding the necessary changes.

Summary of changes:

 - Switch Tosa_IntArrayAttr[N], Tosa_IntArrayAttrUpto[N] to DenseI64ArrayAttr.

 - Replace kNoIterationLimit with kNoLimit. (https://reviews.llvm.org/D140525)

 - Add dependency on MhloPasses when MHLO is enabled

 - Specify result type when using mhlo::DotOp
pull/1827/head
Ashay Rane 2023-01-10 17:07:19 -06:00 committed by GitHub
parent 0979df6589
commit 0faba6d2fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 107 additions and 93 deletions

@ -1 +1 @@
Subproject commit 7ccbb4dff10efe6c26219204e361ddb0264938b8 Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit 8c703fabd60d4447bc86f432446e9ad0eacab600 Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4

View File

@ -18,7 +18,9 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils) TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToMhlo) list(APPEND linked_libs
MhloPasses
TorchMLIRTorchToMhlo)
endif() endif()
add_mlir_library(TorchMLIRConversionPasses add_mlir_library(TorchMLIRConversionPasses

View File

@ -11,7 +11,7 @@
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h" #include "mhlo/transforms/passes.h"
#include "mlir-hlo/Transforms/passes.h" #include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
@ -37,7 +37,7 @@ void mlir::torch::registerConversionPasses() {
return mlir::mhlo::createLegalizeHloToLinalgPass(); return mlir::mhlo::createLegalizeHloToLinalgPass();
}); });
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createSymbolicShapeOptimizationPass(); return mlir::mhlo::createSymbolicShapeOptimizationPass();
}); });
#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCH_MLIR_ENABLE_MHLO
} }

View File

@ -216,7 +216,10 @@ public:
} }
if (lhsRank <= 2 && rhsRank <= 2) { if (lhsRank <= 2 && rhsRank <= 2) {
output = rewriter.create<mhlo::DotOp>(op->getLoc(), lhs, rhs, nullptr); auto tensorType =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
nullptr);
return success(); return success();
} }

View File

@ -881,7 +881,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
newOutputTy), newOutputTy),
self, rewriter.getI64ArrayAttr(newOutputShape)); self, rewriter.getDenseI64ArrayAttr(newOutputShape));
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@ -1076,7 +1076,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
lhsBroadcastedTy), lhsBroadcastedTy),
lhs, rewriter.getI64ArrayAttr(lhsBroadcastedShape)); lhs, rewriter.getDenseI64ArrayAttr(lhsBroadcastedShape));
auto rankBroadcastedRhs = auto rankBroadcastedRhs =
rhsRank == maxInputRank rhsRank == maxInputRank
@ -1085,7 +1085,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
rhsBroadcastedTy), rhsBroadcastedTy),
rhs, rewriter.getI64ArrayAttr(rhsBroadcastedShape)); rhs, rewriter.getDenseI64ArrayAttr(rhsBroadcastedShape));
// TOSA matmul is performed on two 3D inputs and generates a 3D output. // TOSA matmul is performed on two 3D inputs and generates a 3D output.
// Lower ranked tensors are dim-1 reshaped up to 3D // Lower ranked tensors are dim-1 reshaped up to 3D
@ -1113,7 +1113,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
newType), newType),
tensor, rewriter.getI64ArrayAttr(newShape)); tensor, rewriter.getDenseI64ArrayAttr(newShape));
}; };
// Where broadcasting is required in one or more batch dims, the following // Where broadcasting is required in one or more batch dims, the following
@ -1303,7 +1303,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
newLhsType), newLhsType),
lhsReshapeInput, rewriter.getI64ArrayAttr(newLhsShape)); lhsReshapeInput, rewriter.getDenseI64ArrayAttr(newLhsShape));
SmallVector<int64_t> transposedRhsShape; SmallVector<int64_t> transposedRhsShape;
SmallVector<int32_t> transposedRhsDims; SmallVector<int32_t> transposedRhsDims;
@ -1375,7 +1375,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
newRhsType), newRhsType),
transposedRhsValue, rewriter.getI64ArrayAttr(newRhsShape)); transposedRhsValue, rewriter.getDenseI64ArrayAttr(newRhsShape));
} }
auto matmulLhsShape = makeShapeTorchCompatible( auto matmulLhsShape = makeShapeTorchCompatible(
@ -1506,7 +1506,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
reshapedOpType), reshapedOpType),
mmOpResult, rewriter.getI64ArrayAttr(reshapedOpShape)); mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape));
if (opNeedsTranspose) { if (opNeedsTranspose) {
@ -1915,9 +1915,9 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.create<tosa::Conv2DOp>(op->getLoc(), .create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy), getTypeConverter()->convertType(convOpTy),
transposedInput, transposedWeight, bias, transposedInput, transposedWeight, bias,
rewriter.getI64ArrayAttr(padding), rewriter.getDenseI64ArrayAttr(padding),
rewriter.getI64ArrayAttr(stride), rewriter.getDenseI64ArrayAttr(stride),
rewriter.getI64ArrayAttr(dilation)) rewriter.getDenseI64ArrayAttr(dilation))
.getResult(); .getResult();
std::optional<Value> nhwcToNchwTransposeConst = std::optional<Value> nhwcToNchwTransposeConst =
@ -1979,7 +1979,7 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(newType), self, op, getTypeConverter()->convertType(newType), self,
rewriter.getI64ArrayAttr(newShape)); rewriter.getDenseI64ArrayAttr(newShape));
return success(); return success();
} }
@ -2078,7 +2078,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
outTensorType.getElementType()); outTensorType.getElementType());
result = rewriter.create<tosa::ReshapeOp>( result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape)); op->getLoc(), newType, toBcast,
rewriter.getDenseI64ArrayAttr(newShape));
return success(); return success();
}; };
@ -2203,8 +2204,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
sumDiv, rewriter.getI64IntegerAttr(i)); sumDiv, rewriter.getI64IntegerAttr(i));
} }
return rewriter.create<tosa::ReshapeOp>(op.getLoc(), outType, sumDiv, return rewriter.create<tosa::ReshapeOp>(
rewriter.getI64ArrayAttr(outShape)); op.getLoc(), outType, sumDiv, rewriter.getDenseI64ArrayAttr(outShape));
}; };
// TOSA has integer Div so, compute reciprocal of element count to be used in // TOSA has integer Div so, compute reciprocal of element count to be used in
@ -2260,11 +2261,11 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
Value weightVal = rewriter.create<tosa::ReshapeOp>( Value weightVal = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(), op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(),
rewriter.getI64ArrayAttr(weightAndBiasBcastShape)); rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape));
Value biasVal = rewriter.create<tosa::ReshapeOp>( Value biasVal = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), weightAndMeanBcastType, adaptor.getBias(), op.getLoc(), weightAndMeanBcastType, adaptor.getBias(),
rewriter.getI64ArrayAttr(weightAndBiasBcastShape)); rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape));
double eps; double eps;
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
@ -2365,8 +2366,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
selfType.getElementType()); selfType.getElementType());
auto reshapeOp = rewriter.create<tosa::ReshapeOp>( auto reshapeOp =
op.getLoc(), newType, adaptor.getSelf(), rewriter.getI64ArrayAttr(newShape)); rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(newShape));
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), reshapeOp); op, getTypeConverter()->convertType(op.getType()), reshapeOp);
@ -2530,7 +2532,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getI64ArrayAttr(outShape)); rewriter.getDenseI64ArrayAttr(outShape));
return success(); return success();
} }
@ -2603,7 +2605,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getI64ArrayAttr(outShape)); rewriter.getDenseI64ArrayAttr(outShape));
return success(); return success();
} }
@ -2838,7 +2840,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
op->getLoc(), op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape), RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape),
weightType.getElementType()), weightType.getElementType()),
weight, rewriter.getI64ArrayAttr(newWeightShape)); weight, rewriter.getDenseI64ArrayAttr(newWeightShape));
int64_t numIndices = 1; int64_t numIndices = 1;
if (indicesType.hasStaticShape()) { if (indicesType.hasStaticShape()) {
@ -2853,7 +2855,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
op->getLoc(), op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape),
indicesType.getElementType()), indicesType.getElementType()),
indices, rewriter.getI64ArrayAttr(newIndicesShape)); indices, rewriter.getDenseI64ArrayAttr(newIndicesShape));
auto castIndices = rewriter.create<tosa::CastOp>( auto castIndices = rewriter.create<tosa::CastOp>(
op->getLoc(), op->getLoc(),
@ -2870,7 +2872,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, outType, gatherOp, op, outType, gatherOp,
rewriter.getI64ArrayAttr(makeShapeTorchCompatible(outType.getShape()))); rewriter.getDenseI64ArrayAttr(
makeShapeTorchCompatible(outType.getShape())));
return success(); return success();
} }
@ -2960,7 +2963,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
} }
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim); auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim);
auto prunedShapeAttr = rewriter.getI64ArrayAttr(prunedShape); auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
Value reduceMax = rewriter.create<tosa::ReduceMaxOp>( Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
op->getLoc(), op->getLoc(),
@ -2975,7 +2978,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
if (argMax.getType() != indicesType) { if (argMax.getType() != indicesType) {
argMax = rewriter.create<tosa::ReshapeOp>( argMax = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMax, op->getLoc(), indicesType, argMax,
rewriter.getI64ArrayAttr(reducedShape)); rewriter.getDenseI64ArrayAttr(reducedShape));
} }
if (!keepDim) { if (!keepDim) {
@ -3043,8 +3046,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::SliceOp>( rewriter.replaceOpWithNewOp<tosa::SliceOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getI64ArrayAttr(startSlice), rewriter.getDenseI64ArrayAttr(startSlice),
rewriter.getI64ArrayAttr(sizeSlice)); rewriter.getDenseI64ArrayAttr(sizeSlice));
return success(); return success();
} }
@ -3427,8 +3430,9 @@ public:
// function also transposes inputs. // function also transposes inputs.
virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor, virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
Value &input, ArrayAttr &kernel, Value &input, DenseI64ArrayAttr &kernel,
ArrayAttr &stride, ArrayAttr &pad, DenseI64ArrayAttr &stride,
DenseI64ArrayAttr &pad,
Type &outputTy) const { Type &outputTy) const {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented pooling input parsing function"); op, "Unimplemented pooling input parsing function");
@ -3503,7 +3507,7 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value input; Value input;
ArrayAttr kernel, stride, pad; DenseI64ArrayAttr kernel, stride, pad;
Type outputTy; Type outputTy;
// Attempts to read input and kernel parameters, or synthesize them in the // Attempts to read input and kernel parameters, or synthesize them in the
@ -3540,8 +3544,9 @@ public:
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor, LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input, ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride, DenseI64ArrayAttr &kernel,
ArrayAttr &pad, Type &outputTy) const override { DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
auto inputXchw = adaptor.getSelf(); auto inputXchw = adaptor.getSelf();
auto inputTy = inputXchw.getType().template cast<RankedTensorType>(); auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
if (!inputTy) if (!inputTy)
@ -3603,12 +3608,12 @@ public:
input = input =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc( ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
op, rewriter, inputXchw); op, rewriter, inputXchw);
kernel = rewriter.getI64ArrayAttr(kernelDims); kernel = rewriter.getDenseI64ArrayAttr(kernelDims);
stride = rewriter.getI64ArrayAttr({strideH, strideW}); stride = rewriter.getDenseI64ArrayAttr({strideH, strideW});
// Adaptive pooling does unit dilation and zero pad. // Adaptive pooling does unit dilation and zero pad.
pad = rewriter.getI64ArrayAttr({0, 0, 0, 0}); pad = rewriter.getDenseI64ArrayAttr({0, 0, 0, 0});
outputTy = outputTy = RankedTensorType::get(makeShapeLLVMCompatible(outputShape),
RankedTensorType::get(makeShapeLLVMCompatible(outputShape), inputElemTy); inputElemTy);
return success(); return success();
} }
@ -3643,8 +3648,9 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
template <typename AtenOpT, typename tosaOp> template <typename AtenOpT, typename tosaOp>
static LogicalResult getOutputTypeAndPoolingParameters( static LogicalResult getOutputTypeAndPoolingParameters(
AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw, AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy, ArrayAttr &kernel, SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
ArrayAttr &stride, ArrayAttr &pad) { DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
DenseI64ArrayAttr &pad) {
RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>(); RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
if (!inputTy) if (!inputTy)
@ -3669,9 +3675,9 @@ static LogicalResult getOutputTypeAndPoolingParameters(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported"); op, "Non-const padding factor for pooling op unsupported");
kernel = rewriter.getI64ArrayAttr(kernelSizeInts); kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
stride = rewriter.getI64ArrayAttr(strideInts); stride = rewriter.getDenseI64ArrayAttr(strideInts);
pad = rewriter.getI64ArrayAttr( pad = rewriter.getDenseI64ArrayAttr(
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}); {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
// FIXME: add ceil_mode support. // FIXME: add ceil_mode support.
@ -3696,10 +3702,12 @@ public:
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp; tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor, LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input, ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride, DenseI64ArrayAttr &kernel,
ArrayAttr &pad, Type &outputTy) const override { DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
SmallVector<int64_t, 2> dilationArray; SmallVector<int64_t, 2> dilationArray;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationArray))) if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationArray)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Non-const dilation for pooling op unsupported."); op, "Non-const dilation for pooling op unsupported.");
// TOSA pooling only supports unit dilation. // TOSA pooling only supports unit dilation.
@ -3729,8 +3737,9 @@ public:
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp; tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor, LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input, ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride, DenseI64ArrayAttr &kernel,
ArrayAttr &pad, Type &outputTy) const override { DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
SmallVector<int64_t, 2> dilationArray{1, 1}; SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp, if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
tosa::AvgPool2dOp>( tosa::AvgPool2dOp>(

View File

@ -151,7 +151,7 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
auto indicesChosenAxis = tosa::CreateOpAndInfer<tosa::ReshapeOp>( auto indicesChosenAxis = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()), GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()),
indexValue, rewriter.getI64ArrayAttr(indicesOneDimShape)); indexValue, rewriter.getDenseI64ArrayAttr(indicesOneDimShape));
SmallVector<Value> concatInputs; SmallVector<Value> concatInputs;
for (auto dim = 0; dim < paramsRank; dim++) { for (auto dim = 0; dim < paramsRank; dim++) {
@ -312,14 +312,14 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>( auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()), GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()),
paramsValue, rewriter.getI64ArrayAttr(tosaValuesShape)); paramsValue, rewriter.getDenseI64ArrayAttr(tosaValuesShape));
// %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
// tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix. // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>( auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesValue, rewriter.getI64ArrayAttr(indicesMatrixShape)); indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
SmallVector<int32_t> flattenedCoeffVec; // [12,3,1] SmallVector<int32_t> flattenedCoeffVec; // [12,3,1]
// flattenedCoeffVec = [4,3,1] // flattenedCoeffVec = [4,3,1]
@ -367,7 +367,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
rewriter, op->getLoc(), rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()), GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
flattenedIndicesReduceOp.getResult(), flattenedIndicesReduceOp.getResult(),
rewriter.getI64ArrayAttr(tosaIndicesShape)); rewriter.getDenseI64ArrayAttr(tosaIndicesShape));
// Now the gather op itself // Now the gather op itself
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
@ -384,7 +384,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
// %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
return tosa::CreateOpAndInfer<tosa::ReshapeOp>( return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(), rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
rewriter.getI64ArrayAttr(resultType.getShape())) rewriter.getDenseI64ArrayAttr(resultType.getShape()))
.getResult(); .getResult();
} }
@ -446,7 +446,7 @@ std::optional<Value> convertReduceOpCommon(
if (!keep_dims) { if (!keep_dims) {
auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>( auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), output_type, val, rewriter, op->getLoc(), output_type, val,
rewriter.getI64ArrayAttr(output_shape)); rewriter.getDenseI64ArrayAttr(output_shape));
val = reshape_op.getResult(); val = reshape_op.getResult();
} }
} }

View File

@ -32,9 +32,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
rewriter, op->getLoc(), output_type, input_val, rewriter, op->getLoc(), output_type, input_val,
rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)), rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)), rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), rewriter.getDenseI32ArrayAttr({multiplier}),
rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
rewriter.getBoolAttr(false)); rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false));
return rescale_op.getResult(); return rescale_op.getResult();
} }
@ -85,8 +85,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>( auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getI32ArrayAttr({multiplier}), rewriter.getDenseI32ArrayAttr({multiplier}),
rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
return rescale_op.getResult(); return rescale_op.getResult();
@ -121,8 +121,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>( auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getI32ArrayAttr(multiplier_arr), rewriter.getDenseI32ArrayAttr(multiplier_arr),
rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(true)); rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
return rescale_op.getResult(); return rescale_op.getResult();

View File

@ -3667,7 +3667,7 @@ public:
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) { config))) {

View File

@ -194,7 +194,7 @@ class SimplifyDtypeCalculationsPass
// A single linear scan should suffice. // A single linear scan should suffice.
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) { config))) {
return signalPassFailure(); return signalPassFailure();

View File

@ -384,7 +384,7 @@ class SimplifyShapeCalculationsPass
// A single linear scan should suffice. // A single linear scan should suffice.
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) { config))) {
return signalPassFailure(); return signalPassFailure();

View File

@ -224,7 +224,7 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[ARG3:.*]] = torch.constant.int 0 // CHECK: %[[ARG3:.*]] = torch.constant.int 0
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int> // CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32> // CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808>} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -246,7 +246,7 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
@ -264,7 +264,7 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@ -280,7 +280,7 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.
// CHECK: %[[ARG1:.*]] = torch.constant.int 0 // CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG2:.*]] = torch.constant.bool false // CHECK: %[[ARG2:.*]] = torch.constant.bool false
// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1> // CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808>} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
@ -299,7 +299,7 @@ func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !to
// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@ -467,7 +467,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_2:.*]] = torch.constant.int -1
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int> // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-1]} : (tensor<?x?x?x?xf32>) -> tensor<?xf32> // CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: -1>} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
// CHECK: } // CHECK: }
@ -489,10 +489,10 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to
// CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true // CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.bool false // CHECK: %[[VAL_7:.*]] = torch.constant.bool false
// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() {value = dense<9.99999974E-6> : tensor<f32>} : () -> tensor<f32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() {value = dense<9.99999974E-6> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32> // CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32>
@ -521,7 +521,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_2:.*]] = torch.constant.int 4
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> // CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 10, 3, 216, 4>} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
// CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
@ -551,17 +551,17 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32> // CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
@ -681,7 +681,7 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 4, 3, 1>} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
// CHECK: } // CHECK: }
@ -698,7 +698,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_2:.*]] = torch.constant.int -1
// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 4, 3, 1>} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
// CHECK: } // CHECK: }
@ -778,7 +778,7 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> // CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> // CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> // CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
@ -809,7 +809,7 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> // CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> // CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
@ -927,18 +927,18 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_4:.*]] = torch.constant.int -1
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false // CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> // CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = [1, 4, 2, 1]} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array<i64: 1, 4, 2, 1>} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> // CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> // CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 1, 12, 1>} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = array<i64: 8, 3>} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32> // CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32> // CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = array<i64: 1, 8>} : (tensor<8x1xi32>) -> tensor<1x8xi32>
// CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> // CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> // CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = array<i64: 1, 4, 2>} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32>
// CHECK: } // CHECK: }