* [tosa] Support for AtenNativeLayerNormOp

* [tosa] Support for AtenPermuteOp

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
pull/561/head snapshot-20220204.248
Anup Gangwar 2022-02-03 14:08:19 -08:00 committed by Yi Zhang
parent ccf546f14c
commit f9f97ea184
4 changed files with 355 additions and 56 deletions


@@ -114,6 +114,31 @@ def NativeLayerNormModule_basic(module, tu: TestUtils):
# ==============================================================================
class NativeLayerNormModule4D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        list = [2, 2, 3]
        return torch.ops.aten.native_layer_norm(
            x, list, weight, bias, eps=0.5)[0]


@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
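For reference, a minimal PyTorch sketch (not part of this commit) of what the 4D test above exercises: native_layer_norm reduces the last three dims given by normalized_shape = [2, 2, 3] to compute the mean and biased variance, then normalizes, scales by weight, and shifts by bias.

```python
import torch

# Shapes mirror the NativeLayerNormModule4D test above.
x = torch.rand(5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)
eps = 0.5

mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)  # biased variance
out = (x - mean) / torch.sqrt(var + eps) * weight + bias

ref = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, eps)[0]
assert torch.allclose(out, ref, atol=1e-6)
```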
# ==============================================================================
class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()


@@ -92,4 +92,8 @@ TOSA_PASS_SET = {
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
}


@@ -1793,48 +1793,9 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
return success();
}
// Normalization ops perform elementwise ops with a single mean/stdev value
// against the feature map; because the input is NCHW, the rank-1 value must
// be reshaped so it sits on the same dim as 'C'.
static LogicalResult reshapeToNormInputDim(Operation *op,
ConversionPatternRewriter &rewriter,
TypeConverter *converter,
Type outType, const Value toBcast,
Value &result) {
RankedTensorType toBcastType = toBcast.getType().dyn_cast<RankedTensorType>();
if (toBcastType.getRank() > 1)
op->emitError("Rank cannot be more than 1");
RankedTensorType outTensorType = outType.cast<RankedTensorType>();
SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
for (auto i = 2; i < outTensorType.getRank(); ++i)
newShape.push_back(1);
auto newType =
RankedTensorType::get(newShape, outTensorType.getElementType());
result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), converter->convertType(newType), toBcast,
rewriter.getI64ArrayAttr(newShape));
return success();
}
// This lowering is based on the TensorFlow to TOSA lowering.
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a ranked tensor output
if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
return op.emitError("Only ranked tensor types are supported");
auto outType = getTypeConverter()->convertType(op.getType());
// FIXME: Handle training, momentum and cudnn_enabled
if (op.momentum().getType().isa<Torch::NoneType>())
op.emitError("Unsupported None for momentum");
Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter,
Type outType, Value input, Value variance, Value eps,
Value mean, Value weight, Value bias) {
// For PyTorch:
// scale = gamma = weight
// offset = beta = bias
@@ -1858,11 +1819,76 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// op5 = mul(op4, bscale)
// op6 = add(op5, boffset)
auto op1SubInputMean =
rewriter.create<tosa::SubOp>(op->getLoc(), outType, input, mean);
auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
op->getLoc(), variance.getType(), variance, eps);
auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());
auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
op1SubInputMean.getResult(),
op3RsqrtOp2.getResult(), 0);
auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);
return rewriter
.create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
bias)
.getResult();
}
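For reference, a rough NumPy sketch (not part of this commit) of the op1..op6 chain that computeBatchNorm emits, i.e. out = (input - mean) * rsqrt(var + eps) * weight + bias; the shapes below are hypothetical and assume the operands have already been reshaped to broadcast on the channel dim.

```python
import numpy as np

x = np.random.rand(2, 3, 4, 4).astype(np.float32)     # NCHW input (assumed)
mean = np.random.rand(3, 1, 1).astype(np.float32)     # already reshaped to
var = np.random.rand(3, 1, 1).astype(np.float32)      # broadcast on 'C'
weight = np.random.rand(3, 1, 1).astype(np.float32)
bias = np.random.rand(3, 1, 1).astype(np.float32)
eps = np.float32(1e-5)

op1 = x - mean               # tosa.sub
op2 = var + eps              # tosa.add
op3 = 1.0 / np.sqrt(op2)     # tosa.rsqrt
op4 = op1 * op3              # tosa.mul
op5 = op4 * weight           # tosa.mul (scale)
op6 = op5 + bias             # tosa.add (offset)
assert op6.shape == x.shape
```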
// This lowering is based on the TensorFlow to TOSA lowering.
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a ranked tensor output
if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
return op.emitError("Only ranked tensor types are supported");
auto outType = getTypeConverter()->convertType(op.getType());
// Note: cudnn_enabled is not handled.
// FIXME: Handle training and momentum.
if (op.momentum().getType().isa<Torch::NoneType>())
op.emitError("Unsupported None for momentum");
auto meanType = adaptor.running_mean().getType().dyn_cast<TensorType>();
auto varianceType = adaptor.running_var().getType().dyn_cast<TensorType>();
if (!varianceType || !meanType)
return op.emitError("Only ranked tensor types are supported");
// Normalization ops perform elementwise ops with a single mean/stdev value
// against the feature map; because the input is NCHW, the rank-1 value must
// be reshaped so it sits on the same dim as 'C'.
auto reshapeToNormInputDim = [&](Operation *op,
ConversionPatternRewriter &rewriter,
TypeConverter *converter, Type outType,
const Value toBcast, Value &result) {
RankedTensorType toBcastType =
toBcast.getType().dyn_cast<RankedTensorType>();
if (toBcastType.getRank() > 1)
op->emitError("Rank cannot be more than 1");
RankedTensorType outTensorType = outType.cast<RankedTensorType>();
SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
for (auto i = 2; i < outTensorType.getRank(); ++i)
newShape.push_back(1);
auto newType =
RankedTensorType::get(newShape, outTensorType.getElementType());
result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape));
return success();
};
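For reference, a small NumPy sketch (not part of this commit) of what reshapeToNormInputDim achieves for an NCHW input: the rank-1 per-channel value is reshaped to [C, 1, 1] (the first dim followed by rank - 2 ones) so it broadcasts along 'C'. Shapes are hypothetical.

```python
import numpy as np

x = np.random.rand(2, 3, 4, 4).astype(np.float32)   # NCHW feature map (assumed)
mean = np.random.rand(3).astype(np.float32)         # rank-1, per channel

mean_bcast = mean.reshape(3, 1, 1)                   # like the tosa.reshape
centered = x - mean_bcast                            # broadcasts over 'C'
assert centered.shape == x.shape
```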
Value meanVal, varianceVal, weightVal, biasVal;
assert(meanType.getNumElements() != 0 && varianceType.getNumElements() != 0);
if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
@@ -1892,26 +1918,164 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
auto epsilonConst =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
auto op1SubInputMean = rewriter.create<tosa::SubOp>(op->getLoc(), outType,
adaptor.input(), meanVal);
auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
op->getLoc(), varianceVal.getType(), varianceVal, epsilonConst);
auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
op->getLoc(), varianceVal.getType(), op2AddVarEpsilon.getResult());
auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
op1SubInputMean.getResult(),
op3RsqrtOp2.getResult(), 0);
auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
op->getLoc(), outType, op4MulOp1Op3.getResult(), weightVal, 0);
auto op6AddOp5Offset = rewriter.create<tosa::AddOp>(
op->getLoc(), outType, op5MulOp4Scale.getResult(), biasVal);
rewriter.replaceOp(op, {op6AddOp5Offset.getResult()});

auto batchNorm =
computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
epsilonConst, meanVal, weightVal, biasVal);
rewriter.replaceOp(op, {batchNorm});
return success();
}

// This lowering is loosely based on the Torch to LinAlg lowering.
template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
AtenNativeLayerNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// The key difference from BatchNorm is that a specified set of dims
// (normalized_shape) is used to compute the mean and variance from the
// input, whereas in BatchNorm the mean and variance are operands.
// tosa::ReduceSumOp is used to sum up these dims for the mean and variance;
// the results are eventually reshaped for broadcasting.
// Not a ranked tensor output
if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
return op.emitError("Only ranked tensor types are supported");
auto inputType = adaptor.input().getType().cast<RankedTensorType>();
if (inputType.getRank() > 4)
return op.emitError("Only up to 4D tensors are supported");
auto outType = getTypeConverter()->convertType(op.getType(0));
// Note: cudnn_enabled is not handled.
// FIXME: Handle the None cases for the optional parameters.
if (adaptor.weight().getType().isa<Torch::NoneType>())
return op.emitError("Unsupported None for weight");
if (adaptor.bias().getType().isa<Torch::NoneType>())
return op.emitError("Unsupported None for bias");
auto weightType = adaptor.weight().getType().cast<RankedTensorType>();
auto biasType = adaptor.bias().getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
Type elemTy = inputType.getElementType();
// Check if all the arguments meet the requirements.
SmallVector<int64_t> normalizedShapeSizesInt;
if (!matchPattern(op.normalized_shape(),
m_TorchConstantIntList(normalizedShapeSizesInt))) {
return rewriter.notifyMatchFailure(op, "Unimplemented normalized_shape not"
"constructed from ListConstruct");
}
int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
if (weightType.getRank() != normalizedShapeRank ||
biasType.getRank() != normalizedShapeRank ||
inputRank < normalizedShapeRank || normalizedShapeRank < 1)
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
"normalized shape not compatible");
// Check that all the dimensions match normalized_shape; only static shapes
// are handled for now.
int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
for (auto en : llvm::enumerate((normalizedShapeSizesInt))) {
int64_t index = en.index();
int64_t value = en.value();
if (inputType.getShape()[index + meanAndVarShapeRank] != value ||
weightType.getShape()[index] != value ||
biasType.getShape()[index] != value)
return op.emitError("mismatching contracting dimension");
}
// Helper for computing mean and variance.
auto computeSumAndReshape = [&](Value toReduce, RankedTensorType toReduceType,
Type outType, SmallVector<int64_t> outShape) {
Value sumDiv = toReduce;
SmallVector<int64_t> toReduceShape(toReduceType.getShape().begin(),
toReduceType.getShape().end());
while (static_cast<int64_t>(toReduceShape.size()) != meanAndVarShapeRank) {
toReduceShape.back() = 1;
sumDiv = rewriter.create<tosa::ReduceSumOp>(
op.getLoc(),
RankedTensorType::get(toReduceShape, inputType.getElementType()),
sumDiv, rewriter.getI64IntegerAttr(toReduceShape.size() - 1));
toReduceShape.pop_back();
}
return rewriter.create<tosa::ReshapeOp>(op.getLoc(), outType, sumDiv,
rewriter.getI64ArrayAttr(outShape));
};
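For reference, a rough NumPy sketch (not part of this commit) of computeSumAndReshape together with the reciprocal trick used just below: the trailing normalized dims are reduced one tosa.reduce_sum at a time, the result is reshaped to broadcast against the input, and, since TOSA's Div is integer-only, the mean is formed as sum * (1 / elemCnt) via a mul. Shapes follow the 4D test case.

```python
import numpy as np

x = np.random.rand(5, 2, 2, 3).astype(np.float32)

s = x
for axis in (3, 2, 1):                 # one reduce_sum per normalized dim
    s = s.sum(axis=axis, keepdims=True)
s = s.reshape(5, 1, 1, 1)              # the final reshape to bcastOutShape

elem_cnt = 2 * 2 * 3                   # product of normalized_shape
mean = s * np.float32(1.0 / elem_cnt)  # reciprocal + mul instead of div
assert np.allclose(mean.squeeze(), x.mean(axis=(1, 2, 3)))
```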
// TOSA's Div is integer-only, so compute the reciprocal of the element count
// to be used in a Mul.
int64_t elemCnt = 1;
for (auto i : normalizedShapeSizesInt)
elemCnt *= i;
auto elemCntConst =
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(elemCnt)}, {1})
.getValue();
Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), elemCntConst.getType(), elemCntConst);
// Broadcast type and shape for various intermediate values.
SmallVector<int64_t> bcastOutShape;
for (auto en : llvm::enumerate(inputType.getShape())) {
bcastOutShape.push_back(
static_cast<int64_t>(en.index()) >= meanAndVarShapeRank ? 1
: en.value());
}
auto bcastOutType = RankedTensorType::get(bcastOutShape, elemTy);
// Compute mean.
Value sum = computeSumAndReshape(adaptor.input(), inputType, bcastOutType,
bcastOutShape);
Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
elemCntRcp, /*shift=*/0);
// Compute variance.
Value squareSumSub = rewriter.create<tosa::SubOp>(op.getLoc(), inputType,
adaptor.input(), meanVal);
Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
squareSumSub, squareSumSub, 0);
Value squareSumReduced =
computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
Value varianceVal = rewriter.create<tosa::MulOp>(
op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);
// Reshape weight and bias.
SmallVector<int64_t> weightAndBiasBcastShape;
for (auto en : llvm::enumerate(inputType.getShape())) {
weightAndBiasBcastShape.push_back(
static_cast<int64_t>(en.index()) < meanAndVarShapeRank ? 1
: en.value());
}
auto weightAndMeanBcastType =
RankedTensorType::get(weightAndBiasBcastShape, elemTy);
Value weightVal = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), weightAndMeanBcastType, adaptor.weight(),
rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
Value biasVal = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), weightAndMeanBcastType, adaptor.bias(),
rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
double eps;
if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps)))
return op.emitError("eps must be a scalar constant");
auto epsilonConst =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
// Compute layer norm.
auto layerNorm =
computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
epsilonConst, meanVal, weightVal, biasVal);
rewriter.replaceOp(op, {layerNorm, meanVal, varianceVal});
return success();
}
@@ -1987,6 +2151,39 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
AtenPermuteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a ranked tensor type
auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
if (!selfType)
return op.emitError(
"Only ranked tensor types with static shapes are currently supported");
SmallVector<int64_t> dimListInt;
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(dimListInt)))
return rewriter.notifyMatchFailure(
op, "Only constant dimensions are currently supported");
int64_t selfRank = selfType.getRank();
for (auto &d : dimListInt) {
d = toPositiveDim(d, selfRank);
if (!isValidDim(d, selfRank))
return op.emitError("Not all dims are valid");
}
auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
rewriter, op.getOperation(), dimListInt, {selfRank});
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
transposeDimsConst.getValue());
return success();
}
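For reference, a small NumPy sketch (not part of this commit) of what the AtenPermuteOp lowering produces: negative dims are normalized with toPositiveDim, validated, and then applied as a single transpose, which is what the constant perm tensor fed to tosa.transpose encodes. The dims below are illustrative; they resolve to the same [0, 2, 1] permutation as the FileCheck test later in this commit.

```python
import numpy as np

x = np.random.rand(3, 4, 2).astype(np.float32)
dims = [0, -1, 1]                                  # hypothetical user dims
rank = x.ndim

perm = [d if d >= 0 else d + rank for d in dims]   # toPositiveDim
assert all(0 <= d < rank for d in perm)            # isValidDim
y = np.transpose(x, perm)                          # like tosa.transpose
assert y.shape == (3, 2, 4)
```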
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@@ -2429,7 +2626,9 @@ public:
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReshapeOp);
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,


@@ -544,3 +544,74 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
return %0 : !torch.vtensor<[10,3,?,4],f32>
}
// -----
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
// CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
// CHECK: %[[VAL_8:.*]] = torch.constant.int 2
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() {value = dense<1.200000e+01> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_29:.*]] = "tosa.rsqrt"(%[[VAL_28]]) : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_30:.*]] = "tosa.mul"(%[[VAL_27]], %[[VAL_29]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_31:.*]] = "tosa.mul"(%[[VAL_30]], %[[VAL_24]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_32:.*]] = "tosa.add"(%[[VAL_31]], %[[VAL_25]]) : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<!torch.int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
return %result0 : !torch.vtensor<[5,2,2,3],f32>
}
// -----
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int0, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
return %1 : !torch.vtensor<[3,2,4],f32>
}