mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenNativeLayerNormOp
* [tosa] Support for AtenPermuteOp

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
parent ccf546f14c
commit f9f97ea184
@@ -114,6 +114,31 @@ def NativeLayerNormModule_basic(module, tu: TestUtils):
# ==============================================================================


class NativeLayerNormModule4D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        list = [2, 2, 3]
        return torch.ops.aten.native_layer_norm(
            x, list, weight, bias, eps=0.5)[0]


@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))


# ==============================================================================


class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
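For reference, a quick eager-mode sketch of what the new 4D test exercises (not part of the patch; illustrative only): normalized_shape [2, 2, 3] selects the trailing three dims for normalization, and the test compares only result 0, the normalized output.

import torch

x = torch.rand(5, 2, 2, 3)
weight, bias = torch.rand(2, 2, 3), torch.rand(2, 2, 3)
out = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)[0]

# Manual reference: mean/variance over the trailing normalized dims.
dims = (1, 2, 3)
mean = x.mean(dims, keepdim=True)
var = x.var(dims, unbiased=False, keepdim=True)
manual = (x - mean) * torch.rsqrt(var + 0.5) * weight + bias
assert torch.allclose(out, manual, atol=1e-5)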
@@ -92,4 +92,8 @@ TOSA_PASS_SET = {
    "SquareModule_basic",
    "MaxPool2dStaticModule_basic",
    "ResNet18StaticModule_basic",
    "NativeLayerNormModule4D_basic",
    "LayerNormNormalizeOverAllDimsModule_basic",
    "PermuteModule_basic",
    "PermuteNegativeIndexModule_basic",
}
@@ -1793,48 +1793,9 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
  return success();
}

// Normalization ops perform elementwise ops of a single mean/stdev value
// against the feature map and because input is NCHW, the rank-1 value must be
// reshaped so it sits on the same dim as 'C'.
static LogicalResult reshapeToNormInputDim(Operation *op,
                                           ConversionPatternRewriter &rewriter,
                                           TypeConverter *converter,
                                           Type outType, const Value toBcast,
                                           Value &result) {
  RankedTensorType toBcastType = toBcast.getType().dyn_cast<RankedTensorType>();
  if (toBcastType.getRank() > 1)
    op->emitError("Rank cannot be more than 1");

  RankedTensorType outTensorType = outType.cast<RankedTensorType>();
  SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
  for (auto i = 2; i < outTensorType.getRank(); ++i)
    newShape.push_back(1);
  auto newType =
      RankedTensorType::get(newShape, outTensorType.getElementType());

  result = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), converter->convertType(newType), toBcast,
      rewriter.getI64ArrayAttr(newShape));

  return success();
}

// This lowering is based on the TensorFlow to TOSA lowering.
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
    AtenBatchNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a ranked tensor output
  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
    return op.emitError("Only ranked tensor types are supported");

  auto outType = getTypeConverter()->convertType(op.getType());

  // FIXME: Handle training, momentum and cudnn_enabled
  if (op.momentum().getType().isa<Torch::NoneType>())
    op.emitError("Unsupported None for momentum");

Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter,
                       Type outType, Value input, Value variance, Value eps,
                       Value mean, Value weight, Value bias) {
  // For PyTorch:
  //   scale = gamma = weight
  //   offset = beta = bias
@@ -1858,11 +1819,76 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
  // op5 = mul(op4, bscale)
  // op6 = add(op5, boffset)

  auto op1SubInputMean =
      rewriter.create<tosa::SubOp>(op->getLoc(), outType, input, mean);

  auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
      op->getLoc(), variance.getType(), variance, eps);

  auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
      op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());

  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
                                                   op1SubInputMean.getResult(),
                                                   op3RsqrtOp2.getResult(), 0);

  auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
      op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);

  return rewriter
      .create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
                           bias)
      .getResult();
}

// This lowering is based on the TensorFlow to TOSA lowering.
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
    AtenBatchNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a ranked tensor output
  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
    return op.emitError("Only ranked tensor types are supported");

  auto outType = getTypeConverter()->convertType(op.getType());

  // Note: cudnn_enabled is not handled.

  // FIXME: Handle training and momentum.
  if (op.momentum().getType().isa<Torch::NoneType>())
    op.emitError("Unsupported None for momentum");

  auto meanType = adaptor.running_mean().getType().dyn_cast<TensorType>();
  auto varianceType = adaptor.running_var().getType().dyn_cast<TensorType>();
  if (!varianceType || !meanType)
    return op.emitError("Only ranked tensor types are supported");

  // Normalization ops perform elementwise ops of a single mean/stdev value
  // against the feature map and because input is NCHW, the rank-1 value must be
  // reshaped so it sits on the same dim as 'C'.
  auto reshapeToNormInputDim = [&](Operation *op,
                                   ConversionPatternRewriter &rewriter,
                                   TypeConverter *converter, Type outType,
                                   const Value toBcast, Value &result) {
    RankedTensorType toBcastType =
        toBcast.getType().dyn_cast<RankedTensorType>();
    if (toBcastType.getRank() > 1)
      op->emitError("Rank cannot be more than 1");

    RankedTensorType outTensorType = outType.cast<RankedTensorType>();
    SmallVector<int64_t> newShape = {toBcastType.getShape()[0]};
    for (auto i = 2; i < outTensorType.getRank(); ++i)
      newShape.push_back(1);
    auto newType =
        RankedTensorType::get(newShape, outTensorType.getElementType());

    result = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape));

    return success();
  };

  Value meanVal, varianceVal, weightVal, biasVal;
  assert(meanType.getNumElements() != 0 && varianceType.getNumElements() != 0);
  if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
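For context, a minimal sketch (not part of the patch, using eager PyTorch) of the elementwise chain computeBatchNorm builds, op1 through op6 in the comments above; the view calls mirror what reshapeToNormInputDim does, placing the rank-1 statistics, weight and bias on the channel dim of an NCHW input.

import torch

def batch_norm_reference(x, mean, var, weight, bias, eps):
    op1 = x - mean            # sub(input, mean)
    op2 = var + eps           # add(variance, epsilon)
    op3 = torch.rsqrt(op2)    # rsqrt
    op4 = op1 * op3           # mul
    op5 = op4 * weight        # mul by scale (gamma)
    return op5 + bias         # add offset (beta)

x = torch.rand(2, 4, 8, 8)                      # NCHW
mean, var = torch.rand(4), torch.rand(4) + 1.0
weight, bias = torch.rand(4), torch.rand(4)

# Rank-1 values reshaped to [C, 1, 1] so they broadcast along 'C'.
out = batch_norm_reference(x, mean.view(4, 1, 1), var.view(4, 1, 1),
                           weight.view(4, 1, 1), bias.view(4, 1, 1), 1e-5)
ref = torch.nn.functional.batch_norm(x, mean, var, weight, bias,
                                     training=False, eps=1e-5)
assert torch.allclose(out, ref, atol=1e-5)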
@@ -1892,26 +1918,164 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
  auto epsilonConst =
      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);

  auto op1SubInputMean = rewriter.create<tosa::SubOp>(op->getLoc(), outType,
                                                      adaptor.input(), meanVal);
  auto batchNorm =
      computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
                       epsilonConst, meanVal, weightVal, biasVal);

  auto op2AddVarEpsilon = rewriter.create<tosa::AddOp>(
      op->getLoc(), varianceVal.getType(), varianceVal, epsilonConst);
  rewriter.replaceOp(op, {batchNorm});

  auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
      op->getLoc(), varianceVal.getType(), op2AddVarEpsilon.getResult());
  return success();
}

  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
                                                   op1SubInputMean.getResult(),
                                                   op3RsqrtOp2.getResult(), 0);
// This lowering is loosely based on the Torch to LinAlg lowering.
template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
    AtenNativeLayerNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
      op->getLoc(), outType, op4MulOp1Op3.getResult(), weightVal, 0);
  // The key difference from BatchNorm is that a specified set of dims
  // (normalized_shape) are chosen to compute the mean and variance from the
  // input, whereas in BatchNorm the mean and variance are operands.
  // tosa::ReduceSumOp is used to sum over these dims for the mean and the
  // variance, and the results are eventually reshaped for broadcasting.

  auto op6AddOp5Offset = rewriter.create<tosa::AddOp>(
      op->getLoc(), outType, op5MulOp4Scale.getResult(), biasVal);
  // Not a ranked tensor output
  if (!adaptor.input().getType().dyn_cast<RankedTensorType>())
    return op.emitError("Only ranked tensor types are supported");

  rewriter.replaceOp(op, {op6AddOp5Offset.getResult()});
  auto inputType = adaptor.input().getType().cast<RankedTensorType>();
  if (inputType.getRank() > 4)
    return op.emitError("Only up to 4D tensors are supported");

  auto outType = getTypeConverter()->convertType(op.getType(0));

  // Note: cudnn_enabled is not handled.

  // FIXME: Handle the None cases for the optional parameters.
  if (adaptor.weight().getType().isa<Torch::NoneType>())
    return op.emitError("Unsupported None for weight");
  if (adaptor.bias().getType().isa<Torch::NoneType>())
    return op.emitError("Unsupported None for bias");

  auto weightType = adaptor.weight().getType().cast<RankedTensorType>();
  auto biasType = adaptor.bias().getType().cast<RankedTensorType>();
  int64_t inputRank = inputType.getRank();
  Type elemTy = inputType.getElementType();

  // Check if all the arguments meet the requirements.
  SmallVector<int64_t> normalizedShapeSizesInt;
  if (!matchPattern(op.normalized_shape(),
                    m_TorchConstantIntList(normalizedShapeSizesInt))) {
    return rewriter.notifyMatchFailure(op, "Unimplemented normalized_shape not"
                                           " constructed from ListConstruct");
  }
  int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
  if (weightType.getRank() != normalizedShapeRank ||
      biasType.getRank() != normalizedShapeRank ||
      inputRank < normalizedShapeRank || normalizedShapeRank < 1)
    return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
                                           " normalized shape not compatible");

  // Check that all the dimensions match the normalized_shape; only static
  // shapes are handled as of now.
  int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
  for (auto en : llvm::enumerate((normalizedShapeSizesInt))) {
    int64_t index = en.index();
    int64_t value = en.value();
    if (inputType.getShape()[index + meanAndVarShapeRank] != value ||
        weightType.getShape()[index] != value ||
        biasType.getShape()[index] != value)
      return op.emitError("mismatching contracting dimension");
  }

  // Helper for computing mean and variance.
  auto computeSumAndReshape = [&](Value toReduce, RankedTensorType toReduceType,
                                  Type outType, SmallVector<int64_t> outShape) {
    Value sumDiv = toReduce;
    SmallVector<int64_t> toReduceShape(toReduceType.getShape().begin(),
                                       toReduceType.getShape().end());
    while (static_cast<int64_t>(toReduceShape.size()) != meanAndVarShapeRank) {
      toReduceShape.back() = 1;
      sumDiv = rewriter.create<tosa::ReduceSumOp>(
          op.getLoc(),
          RankedTensorType::get(toReduceShape, inputType.getElementType()),
          sumDiv, rewriter.getI64IntegerAttr(toReduceShape.size() - 1));
      toReduceShape.pop_back();
    }

    return rewriter.create<tosa::ReshapeOp>(op.getLoc(), outType, sumDiv,
                                            rewriter.getI64ArrayAttr(outShape));
  };

  // TOSA has only an integer Div, so compute the reciprocal of the element
  // count to be used in a mul.
  int64_t elemCnt = 1;
  for (auto i : normalizedShapeSizesInt)
    elemCnt *= i;

  auto elemCntConst =
      tosa::getConstTensor<float>(rewriter, op.getOperation(),
                                  {static_cast<float>(elemCnt)}, {1})
          .getValue();
  Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
      op.getLoc(), elemCntConst.getType(), elemCntConst);

  // Broadcast type and shape for various intermediate values.
  SmallVector<int64_t> bcastOutShape;
  for (auto en : llvm::enumerate(inputType.getShape())) {
    bcastOutShape.push_back(
        static_cast<int64_t>(en.index()) >= meanAndVarShapeRank ? 1
                                                                : en.value());
  }
  auto bcastOutType = RankedTensorType::get(bcastOutShape, elemTy);

  // Compute mean.
  Value sum = computeSumAndReshape(adaptor.input(), inputType, bcastOutType,
                                   bcastOutShape);
  Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
                                               elemCntRcp, /*shift=*/0);

  // Compute variance.
  Value squareSumSub = rewriter.create<tosa::SubOp>(op.getLoc(), inputType,
                                                    adaptor.input(), meanVal);
  Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
                                                 squareSumSub, squareSumSub, 0);

  Value squareSumReduced =
      computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
  Value varianceVal = rewriter.create<tosa::MulOp>(
      op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);

  // Reshape weight and bias.
  SmallVector<int64_t> weightAndBiasBcastShape;
  for (auto en : llvm::enumerate(inputType.getShape())) {
    weightAndBiasBcastShape.push_back(
        static_cast<int64_t>(en.index()) < meanAndVarShapeRank ? 1
                                                               : en.value());
  }
  auto weightAndMeanBcastType =
      RankedTensorType::get(weightAndBiasBcastShape, elemTy);

  Value weightVal = rewriter.create<tosa::ReshapeOp>(
      op.getLoc(), weightAndMeanBcastType, adaptor.weight(),
      rewriter.getI64ArrayAttr(weightAndBiasBcastShape));

  Value biasVal = rewriter.create<tosa::ReshapeOp>(
      op.getLoc(), weightAndMeanBcastType, adaptor.bias(),
      rewriter.getI64ArrayAttr(weightAndBiasBcastShape));

  double eps;
  if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps)))
    return op.emitError("eps must be a scalar constant");
  auto epsilonConst =
      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);

  // Compute layer norm.
  auto layerNorm =
      computeBatchNorm(op, rewriter, outType, adaptor.input(), varianceVal,
                       epsilonConst, meanVal, weightVal, biasVal);

  rewriter.replaceOp(op, {layerNorm, meanVal, varianceVal});

  return success();
}
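A minimal sketch (not part of the patch) mirroring the decomposition above: successive reduce_sum ops over the normalized dims, a multiply by the reciprocal of the element count in place of a float divide, squared differences for the variance, and then the same sub/add/rsqrt/mul/mul/add chain that computeBatchNorm emits. The function name below is illustrative only.

import torch

def layer_norm_tosa_style(x, normalized_shape, weight, bias, eps):
    norm_rank = len(normalized_shape)
    dims = tuple(range(x.dim() - norm_rank, x.dim()))
    elem_cnt_rcp = 1.0 / float(torch.prod(torch.tensor(normalized_shape)))

    # Mean: sum over each normalized dim, then scale by 1/N
    # (a reciprocal constant plus a mul, standing in for tosa.reciprocal).
    mean = x.sum(dim=dims, keepdim=True) * elem_cnt_rcp
    # Variance: sum of squared differences, scaled by the same reciprocal.
    var = ((x - mean) * (x - mean)).sum(dim=dims, keepdim=True) * elem_cnt_rcp

    # Weight/bias reshaped to [1, ..., *normalized_shape] for broadcasting,
    # then the batch-norm-style elementwise chain.
    w = weight.reshape((1,) * (x.dim() - norm_rank) + tuple(normalized_shape))
    b = bias.reshape_as(w)
    return (x - mean) * torch.rsqrt(var + eps) * w + b

x = torch.rand(5, 2, 2, 3)
weight, bias = torch.rand(2, 2, 3), torch.rand(2, 2, 3)
out = layer_norm_tosa_style(x, [2, 2, 3], weight, bias, eps=0.5)
ref = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)[0]
assert torch.allclose(out, ref, atol=1e-5)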
@@ -1987,6 +2151,39 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
    AtenPermuteOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a ranked tensor type
  auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
  if (!selfType)
    return op.emitError(
        "Only ranked tensor types with static shapes are currently supported");

  SmallVector<int64_t> dimListInt;
  if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(dimListInt)))
    return rewriter.notifyMatchFailure(
        op, "Only constant dimensions are currently supported");

  int64_t selfRank = selfType.getRank();
  for (auto &d : dimListInt) {
    d = toPositiveDim(d, selfRank);
    if (!isValidDim(d, selfRank))
      return op.emitError("Not all dims are valid");
  }

  auto transposeDimsConst = mlir::tosa::getConstTensor<int64_t>(
      rewriter, op.getOperation(), dimListInt, {selfRank});

  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
      op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
      transposeDimsConst.getValue());

  return success();
}

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
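A small sketch (not part of the patch) of the dim handling in this pattern: negative indices are wrapped to positive ones, as toPositiveDim does, and the resulting list becomes the constant perms operand of tosa.transpose. normalize_dims is an illustrative helper, not part of the codebase.

import torch

def normalize_dims(dims, rank):
    pos = [d + rank if d < 0 else d for d in dims]
    assert all(0 <= d < rank for d in pos), "Not all dims are valid"
    return pos

x = torch.rand(3, 4, 2)
dims = [0, -1, 1]                       # permutation with a negative index
perms = normalize_dims(dims, x.dim())   # -> [0, 2, 1], the tosa.transpose perms
assert torch.equal(x.permute(dims), x.permute(perms))
assert x.permute(perms).shape == (3, 2, 4)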
@@ -2429,7 +2626,9 @@ public:
    INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
    INSERT_ATENOP_PATTERN(AtenReshapeOp);
    INSERT_ATENOP_PATTERN(AtenBatchNormOp);
    INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
    INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
    INSERT_ATENOP_PATTERN(AtenPermuteOp);
#undef INSERT_ATENOP_PATTERN

    if (failed(applyPartialConversion(getOperation(), target,
@@ -544,3 +544,74 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
  %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
  return %0 : !torch.vtensor<[10,3,?,4],f32>
}

// -----

// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32>
// CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01
// CHECK: %[[VAL_7:.*]] = torch.constant.int 3
// CHECK: %[[VAL_8:.*]] = torch.constant.int 2
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() {value = dense<1.200000e+01> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1xf32>) -> tensor<5x1xf32>
// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
// CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_29:.*]] = "tosa.rsqrt"(%[[VAL_28]]) : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
// CHECK: %[[VAL_30:.*]] = "tosa.mul"(%[[VAL_27]], %[[VAL_29]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_31:.*]] = "tosa.mul"(%[[VAL_30]], %[[VAL_24]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_32:.*]] = "tosa.add"(%[[VAL_31]], %[[VAL_25]]) : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
  %float5.000000e-01 = torch.constant.float 5.000000e-01
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %0 = torch.prim.ListConstruct %int2, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
  %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2, %float5.000000e-01 : !torch.vtensor<[5,2,2,3],f32>, !torch.list<!torch.int>, !torch.vtensor<[2,2,3],f32>, !torch.vtensor<[2,2,3],f32>, !torch.float -> !torch.vtensor<[5,2,2,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>
  return %result0 : !torch.vtensor<[5,2,2,3],f32>
}

// -----

// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int0, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
  return %1 : !torch.vtensor<[3,2,4],f32>
}