From 5d9a15263a1c43bd62d3f2cd2366b72b186edbb5 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 29 Jan 2022 12:10:50 -0500 Subject: [PATCH] [TORCH] Add aten.std e2e support --- e2e_testing/torchscript/basic.py | 98 +++++++++++++++++ e2e_testing/torchscript/xfail_sets.py | 1 + .../Dialect/Torch/IR/GeneratedAtenOps.td | 58 ++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 102 +++++++++++++++++- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 5 +- .../jit_ir/build_tools/torch_ods_gen.py | 3 + test/Dialect/Torch/decompose-complex-ops.mlir | 100 +++++++++++++++++ 7 files changed, 362 insertions(+), 5 deletions(-) diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index ec33379e7..ec97414a5 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -1006,6 +1006,7 @@ class TModuleRank2(torch.nn.Module): def TModuleRank2_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) +# ============================================================================== class TModuleRank1(torch.nn.Module): def __init__(self): @@ -1023,6 +1024,7 @@ class TModuleRank1(torch.nn.Module): def TModuleRank1_basic(module, tu: TestUtils): module.forward(tu.rand(3)) +# ============================================================================== class TModuleRank0(torch.nn.Module): def __init__(self): @@ -1040,6 +1042,8 @@ class TModuleRank0(torch.nn.Module): def TModuleRank0_basic(module, tu: TestUtils): module.forward(torch.tensor(7, dtype=torch.float32)) +# ============================================================================== + class TensorLiteralModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1057,6 +1061,7 @@ class TensorLiteralModule(torch.nn.Module): def TensorLiteralModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class TensorOpaqueLiteralModule(torch.nn.Module): def __init__(self): @@ -1075,6 +1080,8 @@ class TensorOpaqueLiteralModule(torch.nn.Module): def TensorOpaqueLiteralModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== + class ReturnTwoTensorF32I64(torch.nn.Module): def __init__(self): super().__init__() @@ -1092,6 +1099,7 @@ class ReturnTwoTensorF32I64(torch.nn.Module): def ReturnTwoTensorF32I64_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), torch.randint(5, (2, 3))) +# ============================================================================== class IndexTensorModule(torch.nn.Module): def __init__(self): @@ -1109,3 +1117,93 @@ class IndexTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorModule()) def IndexTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(5), torch.randint(4, (2, 3))) + +# ============================================================================== + +class SquareModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.square(x) + +@register_test_case(module_factory=lambda: SquareModule()) +def SquareModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + +class VarUnbiasedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, 
True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, unbiased=True) + +@register_test_case(module_factory=lambda: VarUnbiasedModule()) +def VarUnbiasedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + +class VarBiasedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, unbiased=False) + +@register_test_case(module_factory=lambda: VarBiasedModule()) +def VarBiasedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + +class StdUnbiasedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.std(x, unbiased=True) + +@register_test_case(module_factory=lambda: StdUnbiasedModule()) +def StdUnbiasedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + +class StdBiasedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.std(x, unbiased=False) + +@register_test_case(module_factory=lambda: StdBiasedModule()) +def StdBiasedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 7af9c5809..482c82038 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -89,4 +89,5 @@ TOSA_PASS_SET = { "FlattenStaticModule_basic", "FlattenRank0Module_basic", "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 1014eca84..3d4b0cf00 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -1172,6 +1172,34 @@ def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [ let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($threshold)) `,` qualified(type($value)) `->` qualified(type($result))"; } +def Torch_AtenSquareOp : Torch_Op<"aten.square", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::square : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))"; +} + +def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::square_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))"; +} + def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ AllowsTypeRefinement, HasValueSemantics @@ -1759,6 +1787,36 @@ def Torch_AtenMeanOp : 
Torch_Op<"aten.mean", [ let assemblyFormat = "$self `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `->` qualified(type($result))"; } +def Torch_AtenStdOp : Torch_Op<"aten.std", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_BoolType:$unbiased + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))"; +} + +def Torch_AtenVarOp : Torch_Op<"aten.var", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::var : (Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_BoolType:$unbiased + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))"; +} + def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ AllowsTypeRefinement, HasValueSemantics diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 90b4f8e27..f06bb87f8 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -403,7 +403,7 @@ public: }; } // namespace -// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks. +// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. namespace { class DecomposeAtenMatmulOp : public OpRewritePattern { public: @@ -459,7 +459,7 @@ public: }; } // namespace -// Decompose torch.expand into torch.broadcast_to op. +// Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { public: @@ -479,7 +479,7 @@ public: }; } // namespace -// Decompose torch.addmm into torch.mm and torch.add.Tensor op. +// Decompose aten.addmm into aten.mm and aten.add.Tensor op. namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { public: @@ -519,7 +519,7 @@ public: }; } // namespace -// Decompose torch.mean into: sum(x)/div(numTensorElements). +// Decompose aten.mean into: sum(x)/div(numTensorElements). namespace { class DecomposeAtenMeanOp : public OpRewritePattern { public: @@ -539,6 +539,94 @@ public: }; } // namespace +namespace { +class DecomposeAtenSquareOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSquareOp op, + PatternRewriter &rewriter) const override { + Value self = op.self(); + rewriter.replaceOpWithNewOp(op, op.getType(), self, self); + return success(); + } +}; +} // namespace + +// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1) +// for unbiased and mean(square(x - mean)) for biased case. 
+namespace {
+class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenVarOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
+    if (!inputTensorTy.hasDtype() ||
+        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "only floating-point types are supported for aten.var");
+    }
+    BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
+    assert(rank0FloatTensorTy.getSizes().size() == 0 &&
+           "Op should have rank 0 tensor type");
+
+    bool unbiased;
+    if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
+      return rewriter.notifyMatchFailure(
+          op, "only a constant unbiased flag is supported for aten.var");
+    }
+
+    Value dtype = rewriter.create<ConstantNoneOp>(loc);
+    Value mean =
+        rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
+    Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
+    Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
+    if (unbiased) {
+      // Bessel’s correction is used. Divide the sum of squared deviations by
+      // numTensorElements - 1.
+      Value squareSum =
+          rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
+      Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
+      Value cst1 = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      Value numTensorElementsSub1 =
+          rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
+      rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
+          op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
+    } else {
+      rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy, square,
+                                              dtype);
+    }
+    return success();
+  }
+};
+} // namespace
+
+// Decompose aten.std into: sqrt(aten.var(x)).
+namespace {
+class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenStdOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.self();
+    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
+    if (!inputTensorTy.hasDtype() ||
+        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "only floating-point types are supported for aten.std");
+    }
+    Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
+                                           op.self(), op.unbiased());
+    rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
@@ -730,6 +818,12 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenLayerNormOp>();
     patterns.add<DecomposeAtenArgMaxOp>(context);
     target.addIllegalOp<AtenArgmaxOp>();
+    patterns.add<DecomposeAtenSquareOp>(context);
+    target.addIllegalOp<AtenSquareOp>();
+    patterns.add<DecomposeAtenVarOp>(context);
+    target.addIllegalOp<AtenVarOp>();
+    patterns.add<DecomposeAtenStdOp>(context);
+    target.addIllegalOp<AtenStdOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index c401e247a..6ba896d6a 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -242,7 +242,7 @@ public:
         AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
         AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
         AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-        AtenAbsOp, AtenThresholdOp>(op)) {
+        AtenAbsOp, AtenThresholdOp, AtenSquareOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -469,6 +469,9 @@ public:
     } else if (auto max = dyn_cast<AtenMaxOp>(op)) {
       Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongAllDimsOp(max, dtype, operands);
+    } else if (isa<AtenStdOp, AtenVarOp>(op)) {
+      auto input = operands[0]->getValue();
+      return visitReductionAlongAllDimsOp(op, input.dtype, operands);
     } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
     } else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 432afb511..ff00ea2e1 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -482,6 +482,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::reciprocal : (Tensor) -> (Tensor)",
         "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
+        "aten::square : (Tensor) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
@@ -538,6 +539,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::sqrt : (Tensor) -> (Tensor)")
     emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
     emit("aten::mean : (Tensor, int?) -> (Tensor)")
+    emit("aten::std : (Tensor, bool) -> (Tensor)")
+    emit("aten::var : (Tensor, bool) -> (Tensor)")
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
 
     # Misc tensor ops.
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index bf4172b8d..be59e1343 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -185,3 +185,103 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
   %0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
   return %0 : !torch.vtensor<[],si64>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.square(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
+func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.aten.var$unbiased(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
+// CHECK: %[[DTYPE:.*]] = torch.constant.none
+// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
+func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+  %true = torch.constant.bool true
+  %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.aten.var$biased(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
+// CHECK: %[[DTYPE:.*]] = torch.constant.none
+// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
+func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+  %false = torch.constant.bool false
+  %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.aten.std$unbiased(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
+// CHECK: %[[DTYPE:.*]] = torch.constant.none
+// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
+func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+  %true = torch.constant.bool true
+  %0 = torch.aten.std %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.aten.std$biased(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
+// CHECK: %[[DTYPE:.*]] = torch.constant.none
+// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
+// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
+func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
+  %false = torch.constant.bool false
+  %0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
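
As a numerical sanity check of the identities implemented by
DecomposeAtenVarOp and DecomposeAtenStdOp, the decompositions can be compared
against PyTorch's public torch.var / torch.std API. A minimal sketch (not part
of the diff above; assumes any reasonably recent PyTorch build):

    import torch

    x = torch.rand(2, 3, 4)
    n = x.numel()
    mean = x.sum() / n
    square = (x - mean) * (x - mean)

    # Unbiased: Bessel's correction divides the squared-deviation sum by n - 1.
    assert torch.allclose(square.sum() / (n - 1), torch.var(x, unbiased=True))
    # Biased: the plain mean of the squared deviations.
    assert torch.allclose(square.sum() / n, torch.var(x, unbiased=False))
    # aten.std decomposes to sqrt(aten.var) in both cases.
    assert torch.allclose(torch.std(x, unbiased=True),
                          torch.sqrt(torch.var(x, unbiased=True)))
    assert torch.allclose(torch.std(x, unbiased=False),
                          torch.sqrt(torch.var(x, unbiased=False)))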