From f9f97ea1842df98b16a4c58d1bd12fa3f7f6dced Mon Sep 17 00:00:00 2001
From: Anup Gangwar
Date: Thu, 3 Feb 2022 14:08:19 -0800
Subject: [PATCH] * [tosa] Support for AtenNativeLayerNormOp
 * [tosa] Support for AtenPermuteOp

Signed-off-by: Anup Gangwar
---
 e2e_testing/torchscript/batchnorm.py       |  25 ++
 e2e_testing/torchscript/xfail_sets.py      |   4 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 311 +++++++++++++++++----
 test/Conversion/TorchToTosa/basic.mlir     |  71 +++++
 4 files changed, 355 insertions(+), 56 deletions(-)

diff --git a/e2e_testing/torchscript/batchnorm.py b/e2e_testing/torchscript/batchnorm.py
index 39071b99f..54e7c8b56 100644
--- a/e2e_testing/torchscript/batchnorm.py
+++ b/e2e_testing/torchscript/batchnorm.py
@@ -114,6 +114,31 @@ def NativeLayerNormModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class NativeLayerNormModule4D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        normalized_shape = [2, 2, 3]
+        return torch.ops.aten.native_layer_norm(
+            x, normalized_shape, weight, bias, eps=0.5)[0]
+
+
+@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
+def NativeLayerNormModule4D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
+
+
+# ==============================================================================
+
+
 class LayerNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 12f1e136f..17d85ddac 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -92,4 +92,8 @@ TOSA_PASS_SET = {
     "SquareModule_basic",
    "MaxPool2dStaticModule_basic",
     "ResNet18StaticModule_basic",
+    "NativeLayerNormModule4D_basic",
+    "LayerNormNormalizeOverAllDimsModule_basic",
+    "PermuteModule_basic",
+    "PermuteNegativeIndexModule_basic",
 }
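For context on what the new NativeLayerNormModule4D test exercises: torch.ops.aten.native_layer_norm returns a tuple whose first element is the normalized output, and the test only compares that element. The sketch below is illustrative only (it is not part of the patch, and the helper name decomposed_layer_norm is ad hoc); it checks that the same mean/variance decomposition used by the TOSA lowering further down reproduces that first result.

import torch

def decomposed_layer_norm(x, normalized_shape, weight, bias, eps):
    # Reduce over the trailing normalized dims, keeping them as size-1 dims
    # so the statistics broadcast against x (as the TOSA lowering's reshapes do).
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) * (x - mean)).mean(dim=dims, keepdim=True)  # biased variance
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias

x = torch.rand(5, 2, 2, 3)
w, b = torch.rand(2, 2, 3), torch.rand(2, 2, 3)
ref = torch.ops.aten.native_layer_norm(x, [2, 2, 3], w, b, 0.5)[0]
assert torch.allclose(decomposed_layer_norm(x, [2, 2, 3], w, b, 0.5), ref, atol=1e-5)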
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index ddf20f6b7..ec5024d7e 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1793,48 +1793,9 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
   return success();
 }
 
-// Normalization ops perform elementwise ops of a single mean/stdev value
-// against the feature map and because input is NCHW, the rank-1 value must be
-// reshaped so it sits on the same dim as 'C'.
-static LogicalResult reshapeToNormInputDim(Operation *op,
-                                           ConversionPatternRewriter &rewriter,
-                                           TypeConverter *converter,
-                                           Type outType, const Value toBcast,
-                                           Value &result) {
-  RankedTensorType toBcastType = toBcast.getType().dyn_cast<RankedTensorType>();
-  if (toBcastType.getRank() > 1)
-    op->emitError("Rank cannot be more than 1");
-
-  RankedTensorType outTensorType = outType.cast<RankedTensorType>();
-  SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
-  for (auto i = 2; i < outTensorType.getRank(); ++i)
-    newShape.push_back(1);
-  auto newType =
-      RankedTensorType::get(newShape, outTensorType.getElementType());
-
-  result = rewriter.create<tosa::ReshapeOp>(
-      op->getLoc(), converter->convertType(newType), toBcast,
-      rewriter.getI64ArrayAttr(newShape));
-
-  return success();
-}
-
-// This lowering is based on the TensorFlow to TOSA lowering.
-template <>
-LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
-    AtenBatchNormOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  // Not a ranked tensor output
-  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
-    return op.emitError("Only ranked tensor types are supported");
-
-  auto outType = getTypeConverter()->convertType(op.getType());
-
-  // FIXME: Handle training, momentum and cudnn_enabled
-  if (op.momentum().getType().isa<Torch::NoneType>())
-    op.emitError("Unsupported None for momentum");
-
+Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter,
+                       Type outType, Value input, Value variance, Value eps,
+                       Value mean, Value weight, Value bias) {
   // For PyTorch:
   //   scale  = gamma = weight
   //   offset = beta  = bias
@@ -1858,11 +1819,76 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   //   op5 = mul(op4, bscale)
   //   op6 = add(op5, boffset)
 
+  auto op1SubInputMean =
+      rewriter.create<tosa::SubOp>(op->getLoc(), outType, input, mean);
+
+  auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
+      op->getLoc(), variance.getType(), variance, eps);
+
+  auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
+      op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());
+
+  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
+                                                   op1SubInputMean.getResult(),
+                                                   op3RsqrtOp2.getResult(), 0);
+
+  auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
+      op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);
+
+  return rewriter
+      .create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
+                           bias)
+      .getResult();
+}
+
+// This lowering is based on the TensorFlow to TOSA lowering.
+template <>
+LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
+    AtenBatchNormOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a ranked tensor output
+  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
+    return op.emitError("Only ranked tensor types are supported");
+
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  // Note: cudnn_enabled is not handled.
+
+  // FIXME: Handle training and momentum.
+  if (op.momentum().getType().isa<Torch::NoneType>())
+    op.emitError("Unsupported None for momentum");
+
   auto meanType = adaptor.running_mean().getType().dyn_cast<RankedTensorType>();
   auto varianceType = adaptor.running_var().getType().dyn_cast<RankedTensorType>();
   if (!varianceType || !meanType)
     return op.emitError("Only ranked tensor types are supported");
 
+  // Normalization ops perform elementwise ops of a single mean/stdev value
+  // against the feature map; because the input is NCHW, the rank-1 value must
+  // be reshaped so it sits on the same dim as 'C'.
+  auto reshapeToNormInputDim = [&](Operation *op,
+                                   ConversionPatternRewriter &rewriter,
+                                   TypeConverter *converter, Type outType,
+                                   const Value toBcast, Value &result) {
+    RankedTensorType toBcastType =
+        toBcast.getType().dyn_cast<RankedTensorType>();
+    if (toBcastType.getRank() > 1)
+      op->emitError("Rank cannot be more than 1");
+
+    RankedTensorType outTensorType = outType.cast<RankedTensorType>();
+    SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
+    for (auto i = 2; i < outTensorType.getRank(); ++i)
+      newShape.push_back(1);
+    auto newType =
+        RankedTensorType::get(newShape, outTensorType.getElementType());
+
+    result = rewriter.create<tosa::ReshapeOp>(
+        op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape));
+
+    return success();
+  };
+
   Value meanVal, varianceVal, weightVal, biasVal;
   assert(meanType.getNumElements() != 0 && varianceType.getNumElements() != 0);
   if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
@@ -1892,26 +1918,164 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   auto epsilonConst =
       mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
 
-  auto op1SubInputMean = rewriter.create<tosa::SubOp>(op->getLoc(), outType,
-                                                      adaptor.input(), meanVal);
+  auto batchNorm =
+      computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
+                       epsilonConst, meanVal, weightVal, biasVal);
 
-  auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
-      op->getLoc(), varianceVal.getType(), varianceVal, epsilonConst);
+  rewriter.replaceOp(op, {batchNorm});
 
-  auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
-      op->getLoc(), varianceVal.getType(), op2AddVarEpsilon.getResult());
+  return success();
+}
 
-  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
-                                                   op1SubInputMean.getResult(),
-                                                   op3RsqrtOp2.getResult(), 0);
+// This lowering is loosely based on the Torch to LinAlg lowering.
+template <>
+LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
+    AtenNativeLayerNormOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
 
-  auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
-      op->getLoc(), outType, op4MulOp1Op3.getResult(), weightVal, 0);
+  // The key difference from BatchNorm is that a specified set of dims
+  // (normalized_shape) is used to compute the mean and variance from the
+  // input, whereas in BatchNorm the mean and variance are operands.
+  // tosa::ReduceSumOp is used to sum up these dims for the mean and for the
+  // variance, and the results are eventually reshaped for broadcasting.
 
-  auto op6AddOp5Offset = rewriter.create<tosa::AddOp>(
-      op->getLoc(), outType, op5MulOp4Scale.getResult(), biasVal);
+  // Not a ranked tensor output
+  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
+    return op.emitError("Only ranked tensor types are supported");
 
-  rewriter.replaceOp(op, {op6AddOp5Offset.getResult()});
+  auto inputType = adaptor.input().getType().cast<RankedTensorType>();
+  if (inputType.getRank() > 4)
+    return op.emitError("Only up to 4D tensors are supported");
+
+  auto outType = getTypeConverter()->convertType(op.getType(0));
+
+  // Note: cudnn_enabled is not handled.
+
+  // FIXME: Handle the None cases for the optional parameters.
+  if (adaptor.weight().getType().isa<Torch::NoneType>())
+    return op.emitError("Unsupported None for weight");
+  if (adaptor.bias().getType().isa<Torch::NoneType>())
+    return op.emitError("Unsupported None for bias");
+
+  auto weightType = adaptor.weight().getType().cast<RankedTensorType>();
+  auto biasType = adaptor.bias().getType().cast<RankedTensorType>();
+  int64_t inputRank = inputType.getRank();
+  Type elemTy = inputType.getElementType();
+
+  // Check if all the arguments meet the requirements.
+  SmallVector<int64_t> normalizedShapeSizesInt;
+  if (!matchPattern(op.normalized_shape(),
+                    m_TorchConstantIntList(normalizedShapeSizesInt))) {
+    return rewriter.notifyMatchFailure(op, "Unimplemented normalized_shape not "
+                                           "constructed from ListConstruct");
+  }
+  int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
+  if (weightType.getRank() != normalizedShapeRank ||
+      biasType.getRank() != normalizedShapeRank ||
+      inputRank < normalizedShapeRank || normalizedShapeRank < 1)
+    return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or "
+                                           "normalized shape not compatible");
+
+  // Check that all the dimensions match normalized_shape; only static shapes
+  // are handled as of now.
+  int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
+  for (auto en : llvm::enumerate((normalizedShapeSizesInt))) {
+    int64_t index = en.index();
+    int64_t value = en.value();
+    if (inputType.getShape()[index + meanAndVarShapeRank] != value ||
+        weightType.getShape()[index] != value ||
+        biasType.getShape()[index] != value)
+      return op.emitError("mismatching contracting dimension");
+  }
+
+  // Helper for computing mean and variance: sum over the normalized dims and
+  // reshape the result so it broadcasts against the input.
+  auto computeSumAndReshape = [&](Value toReduce, RankedTensorType toReduceType,
+                                  Type outType, SmallVector<int64_t> outShape) {
+    Value sumDiv = toReduce;
+    SmallVector<int64_t> toReduceShape(toReduceType.getShape().begin(),
+                                       toReduceType.getShape().end());
+    while (static_cast<int64_t>(toReduceShape.size()) != meanAndVarShapeRank) {
+      toReduceShape.back() = 1;
+      sumDiv = rewriter.create<tosa::ReduceSumOp>(
+          op.getLoc(),
+          RankedTensorType::get(toReduceShape, inputType.getElementType()),
+          sumDiv, rewriter.getI64IntegerAttr(toReduceShape.size() - 1));
+      toReduceShape.pop_back();
+    }
+
+    return rewriter.create<tosa::ReshapeOp>(op.getLoc(), outType, sumDiv,
+                                            rewriter.getI64ArrayAttr(outShape));
+  };
+
+  // TOSA's Div only supports integers, so compute the reciprocal of the
+  // element count and use it in a Mul.
+  int64_t elemCnt = 1;
+  for (auto i : normalizedShapeSizesInt)
+    elemCnt *= i;
+
+  auto elemCntConst =
+      tosa::getConstTensor<float>(rewriter, op.getOperation(),
+                                  {static_cast<float>(elemCnt)}, {1})
+          .getValue();
+  Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
+      op.getLoc(), elemCntConst.getType(), elemCntConst);
+
+  // Broadcast type and shape for various intermediate values.
+  SmallVector<int64_t> bcastOutShape;
+  for (auto en : llvm::enumerate(inputType.getShape())) {
+    bcastOutShape.push_back(
+        static_cast<int64_t>(en.index()) >= meanAndVarShapeRank ? 1
+                                                                : en.value());
+  }
+  auto bcastOutType = RankedTensorType::get(bcastOutShape, elemTy);
+
+  // Compute mean.
+  Value sum = computeSumAndReshape(adaptor.input(), inputType, bcastOutType,
+                                   bcastOutShape);
+  Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
+                                               elemCntRcp, /*shift=*/0);
+
+  // Compute variance.
+  Value squareSumSub = rewriter.create<tosa::SubOp>(op.getLoc(), inputType,
+                                                    adaptor.input(), meanVal);
+  Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
+                                                 squareSumSub, squareSumSub, 0);
+
+  Value squareSumReduced =
+      computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
+  Value varianceVal = rewriter.create<tosa::MulOp>(
+      op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);
+
+  // Reshape weight and bias.
+  SmallVector<int64_t> weightAndBiasBcastShape;
+  for (auto en : llvm::enumerate(inputType.getShape())) {
+    weightAndBiasBcastShape.push_back(
+        static_cast<int64_t>(en.index()) < meanAndVarShapeRank ? 1
+                                                               : en.value());
+  }
+  auto weightAndMeanBcastType =
+      RankedTensorType::get(weightAndBiasBcastShape, elemTy);
+
+  Value weightVal = rewriter.create<tosa::ReshapeOp>(
+      op.getLoc(), weightAndMeanBcastType, adaptor.weight(),
+      rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
+
+  Value biasVal = rewriter.create<tosa::ReshapeOp>(
+      op.getLoc(), weightAndMeanBcastType, adaptor.bias(),
+      rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
+
+  double eps;
+  if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps)))
+    return op.emitError("eps must be a scalar constant");
+  auto epsilonConst =
+      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
+
+  // Compute layer norm.
+  auto layerNorm =
+      computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
+                       epsilonConst, meanVal, weightVal, biasVal);
+
+  rewriter.replaceOp(op, {layerNorm, meanVal, varianceVal});
 
   return success();
 }
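As a rough picture of what computeSumAndReshape and the reciprocal trick produce for the 4D test case in this patch, the sketch that follows is illustrative only and not part of the patch: the normalized dims are collapsed one reduce_sum at a time starting from the innermost axis, the result is reshaped to a broadcastable [5, 1, 1, 1], and the division by the element count is expressed as a multiply by a precomputed reciprocal.

import numpy as np

x = np.random.rand(5, 2, 2, 3).astype(np.float32)
normalized_shape = (2, 2, 3)
mean_and_var_rank = x.ndim - len(normalized_shape)  # 1

# Successive single-axis sums, innermost axis first, as the lowering emits:
# (5,2,2,3) -> (5,2,2) -> (5,2) -> (5,)
acc = x
while acc.ndim != mean_and_var_rank:
    acc = acc.sum(axis=acc.ndim - 1)

# Reshape to the broadcast shape [5, 1, 1, 1] and multiply by the reciprocal
# of the element count (TOSA's Div is integer-only, hence reciprocal + mul).
bcast_shape = x.shape[:mean_and_var_rank] + (1,) * len(normalized_shape)
elem_cnt_rcp = np.float32(1.0) / np.float32(np.prod(normalized_shape))  # 1/12
mean = acc.reshape(bcast_shape) * elem_cnt_rcp

assert mean.shape == (5, 1, 1, 1)
assert np.allclose(mean, x.mean(axis=(1, 2, 3), keepdims=True))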
@@ -1987,6 +2151,39 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
+    AtenPermuteOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a ranked tensor type
+  auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
+  if (!selfType)
+    return op.emitError(
+        "Only ranked tensor types with static shapes are currently supported");
+
+  SmallVector<int64_t> dimListInt;
+  if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(dimListInt)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dimensions are currently supported");
+
+  int64_t selfRank = selfType.getRank();
+  for (auto &d : dimListInt) {
+    d = toPositiveDim(d, selfRank);
+    if (!isValidDim(d, selfRank))
+      return op.emitError("Not all dims are valid");
+  }
+
+  auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
+      rewriter, op.getOperation(), dimListInt, {selfRank});
+
+  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
+      transposeDimsConst.getValue());
+
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -2429,7 +2626,9 @@
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
     INSERT_ATENOP_PATTERN(AtenReshapeOp);
     INSERT_ATENOP_PATTERN(AtenBatchNormOp);
+    INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
     INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
+    INSERT_ATENOP_PATTERN(AtenPermuteOp);
 #undef INSERT_ATENOP_PATTERN
 
   if (failed(applyPartialConversion(getOperation(), target,
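The permute lowering above normalizes the dims list to non-negative indices (toPositiveDim) and then materializes it as a constant permutation tensor feeding tosa.transpose, which the MLIR tests below check. The following Python sketch of the same bookkeeping is illustrative only and not part of the patch (normalize_dims is an ad hoc helper):

import torch

def normalize_dims(dims, rank):
    # Mirror of toPositiveDim/isValidDim: fold negative indices, then bounds-check.
    dims = [d + rank if d < 0 else d for d in dims]
    assert all(0 <= d < rank for d in dims), "Not all dims are valid"
    return dims

x = torch.rand(3, 4, 2)
for dims in ([0, 2, 1], [0, -1, -2]):            # positive and negative forms
    perms = normalize_dims(list(dims), x.dim())  # both become [0, 2, 1]
    assert torch.equal(x.permute(dims), x.permute(perms))
    assert x.permute(perms).shape == (3, 2, 4)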
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index d2aa5bb15..6905e48ef 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -544,3 +544,74 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
   %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
   return %0 : !torch.vtensor<[10,3,?,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @forward(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
+// CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_8:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = "tosa.const"() {value = dense<1.200000e+01> : tensor<1xf32>} : () -> tensor<1xf32>
+// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32>
+// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
+// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
+// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
+// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
+// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
+// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
+// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
+// CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_29:.*]] = "tosa.rsqrt"(%[[VAL_28]]) : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_30:.*]] = "tosa.mul"(%[[VAL_27]], %[[VAL_29]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_31:.*]] = "tosa.mul"(%[[VAL_30]], %[[VAL_24]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_32:.*]] = "tosa.add"(%[[VAL_31]], %[[VAL_25]]) : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
+// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
+// CHECK: }
+func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
+  %float5.000000e-01 = torch.constant.float 5.000000e-01
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
+  return %result0 : !torch.vtensor<[5,2,2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @forward(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
+// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
+// CHECK: }
+func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %int0, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<int> -> !torch.vtensor<[3,2,4],f32>
+  return %1 : !torch.vtensor<[3,2,4],f32>
+}