[TORCH] Add aten.std e2e support

pull/550/head
Yi Zhang 2022-01-29 12:10:50 -05:00
parent e58b66bc3b
commit 5d9a15263a
7 changed files with 362 additions and 5 deletions

View File

@ -1006,6 +1006,7 @@ class TModuleRank2(torch.nn.Module):
def TModuleRank2_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class TModuleRank1(torch.nn.Module):
def __init__(self):
@ -1023,6 +1024,7 @@ class TModuleRank1(torch.nn.Module):
def TModuleRank1_basic(module, tu: TestUtils):
module.forward(tu.rand(3))
# ==============================================================================
class TModuleRank0(torch.nn.Module):
def __init__(self):
@ -1040,6 +1042,8 @@ class TModuleRank0(torch.nn.Module):
def TModuleRank0_basic(module, tu: TestUtils):
module.forward(torch.tensor(7, dtype=torch.float32))
# ==============================================================================
class TensorLiteralModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -1057,6 +1061,7 @@ class TensorLiteralModule(torch.nn.Module):
def TensorLiteralModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class TensorOpaqueLiteralModule(torch.nn.Module):
def __init__(self):
@ -1075,6 +1080,8 @@ class TensorOpaqueLiteralModule(torch.nn.Module):
def TensorOpaqueLiteralModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ReturnTwoTensorF32I64(torch.nn.Module):
def __init__(self):
super().__init__()
@ -1092,6 +1099,7 @@ class ReturnTwoTensorF32I64(torch.nn.Module):
def ReturnTwoTensorF32I64_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.randint(5, (2, 3)))
# ==============================================================================
class IndexTensorModule(torch.nn.Module):
def __init__(self):
@ -1109,3 +1117,93 @@ class IndexTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: IndexTensorModule())
def IndexTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5), torch.randint(4, (2, 3)))
# ==============================================================================
class SquareModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.square(x)
@register_test_case(module_factory=lambda: SquareModule())
def SquareModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class VarUnbiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, unbiased=True)
@register_test_case(module_factory=lambda: VarUnbiasedModule())
def VarUnbiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class VarBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, unbiased=False)
@register_test_case(module_factory=lambda: VarBiasedModule())
def VarBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdUnbiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, unbiased=True)
@register_test_case(module_factory=lambda: StdUnbiasedModule())
def StdUnbiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, unbiased=False)
@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))

View File

@ -89,4 +89,5 @@ TOSA_PASS_SET = {
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
}

View File

@ -1172,6 +1172,34 @@ def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($threshold)) `,` qualified(type($value)) `->` qualified(type($result))";
}
def Torch_AtenSquareOp : Torch_Op<"aten.square", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::square : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::square_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics
@ -1759,6 +1787,36 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
let assemblyFormat = "$self `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `->` qualified(type($result))";
}
def Torch_AtenStdOp : Torch_Op<"aten.std", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_BoolType:$unbiased
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))";
}
def Torch_AtenVarOp : Torch_Op<"aten.var", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::var : (Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_BoolType:$unbiased
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))";
}
def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -403,7 +403,7 @@ public:
};
} // namespace
// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
@ -459,7 +459,7 @@ public:
};
} // namespace
// Decompose torch.expand into torch.broadcast_to op.
// Decompose aten.expand into aten.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
@ -479,7 +479,7 @@ public:
};
} // namespace
// Decompose torch.addmm into torch.mm and torch.add.Tensor op.
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
public:
@ -519,7 +519,7 @@ public:
};
} // namespace
// Decompose torch.mean into: sum(x)/div(numTensorElements).
// Decompose aten.mean into: sum(x)/div(numTensorElements).
namespace {
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
public:
@ -539,6 +539,94 @@ public:
};
} // namespace
namespace {
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSquareOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
return success();
}
};
} // namespace
// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
// for unbiased and mean(square(x - mean)) for biased case.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"Only aten.var support floating type");
}
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
assert(rank0FloatTensorTy.getSizes().size() == 0 &&
"Op should have rank 0 tensor type");
bool unbiased;
if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var");
}
Value dtype = rewriter.create<ConstantNoneOp>(loc);
Value mean =
rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
Value var;
if (unbiased) {
// Bessels correction is used. Divide the square sum by
// numTensorElements-1.
Value squareSum =
rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value numTensorElementsSub1 =
rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
var = rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
} else {
var = rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy,
square, dtype);
}
return success();
}
};
} // namespace
// Decompose aten.std to sqrt(var(x))
namespace {
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenStdOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"Only aten.std support floating type");
}
Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
op.self(), op.unbiased());
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
return success();
}
};
} // namespace
namespace {
template<typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
@ -730,6 +818,12 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenArangeStartOp>();
patterns.add<DecomposeAtenArgMaxOp>(context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<DecomposeAtenSquareOp>(context);
target.addIllegalOp<AtenSquareOp>();
patterns.add<DecomposeAtenVarOp>(context);
target.addIllegalOp<AtenVarOp>();
patterns.add<DecomposeAtenStdOp>(context);
target.addIllegalOp<AtenStdOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -242,7 +242,7 @@ public:
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp>(op)) {
AtenAbsOp, AtenThresholdOp, AtenSquareOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@ -469,6 +469,9 @@ public:
} else if (auto max = dyn_cast<AtenMaxOp>(op)) {
Type dtype = operands[0]->getValue().dtype;
return visitReductionAlongAllDimsOp(max, dtype, operands);
} else if (isa<AtenStdOp, AtenVarOp>(op)) {
auto input = operands[0]->getValue();
return visitReductionAlongAllDimsOp(op, input.dtype, operands);
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {

View File

@ -482,6 +482,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
]:
emit_with_mutating_variants(key)
@ -538,6 +539,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::sqrt : (Tensor) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
# Misc tensor ops.

View File

@ -185,3 +185,103 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
%0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
// -----
// CHECK-LABEL: func @torch.aten.square(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.var$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
%0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.var$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
%0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.std$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
%0 = torch.aten.std %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.std$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
%0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}