[MLIR][TORCH] Add decomposition for aten.var.correction op

This commit adds the decomposition for `aten.var.correction` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com
pull/1122/head
Vivek Khandelwal 2022-07-22 18:12:14 +05:30
parent 7247c6a3a7
commit d386b8f9e5
8 changed files with 510 additions and 128 deletions

View File

@ -4047,6 +4047,32 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
}];
}
def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalIntType:$correction,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenVarCorrectionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -229,8 +229,7 @@ public:
} // namespace
namespace {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenZeroOp op,
@ -705,7 +704,8 @@ public:
// unsqueezed_sizes += [1, s]
// expanded_sizes += [m, s]
// reshaped_sizes += [m * s]
// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
// return
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
//
namespace {
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
@ -754,7 +754,8 @@ public:
assert(leadingRank >= 0 && "leadingRank should greater than 0");
for (size_t i = 0; i < leadingRank; ++i) {
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{repeats[i]});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{repeats[i]});
reshapedSizes.push_back(repeats[i]);
}
@ -772,18 +773,20 @@ public:
loc, rewriter.getI64IntegerAttr(selfShape[i]));
}
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one, dimSize});
insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{scale, dimSize});
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
ArrayRef<Value>{one, dimSize});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{scale, dimSize});
Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
reshapedSizes.push_back(scaledSize);
}
Type dtype = self.getType().cast<ValueTensorType>().getDtype();
Type unsqueezedType =
ValueTensorType::get(context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType =
ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype);
Type unsqueezedType = ValueTensorType::get(
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIntSizes), dtype);
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims =
@ -792,8 +795,8 @@ public:
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
Value reshapedDims =
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
auto reshaped =
rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(), unsqueezedDims);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(),
unsqueezedDims);
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
reshaped, expandedDims);
@ -1184,14 +1187,14 @@ public:
Value input = op.self();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
Value inputTimesBeta = rewriter.create<AtenMulScalarOp>(
loc, inputType, input, op.beta());
Value inputTimesBeta =
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.beta());
// out = log1p(exp(input * beta)) / beta
Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
Value out = rewriter.create<AtenDivScalarOp>(
loc, inputType, log1p, op.beta());
Value out =
rewriter.create<AtenDivScalarOp>(loc, inputType, log1p, op.beta());
// Select where x * beta > threshold
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
@ -1199,8 +1202,8 @@ public:
Value condition = rewriter.create<AtenGtScalarOp>(
loc, boolResType, inputTimesBeta, op.threshold());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
op, op.getType(), condition, input, out);
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), condition,
input, out);
return success();
}
};
@ -2138,6 +2141,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
};
} // namespace
template <typename OpTy>
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
bool unbiased, int64_t correction) {
Location loc = op.getLoc();
Value self = op.self();
Value dimList = op.dim();
Value keepDim = op.keepdim();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
Type newOutputType = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "support floating-point type input only");
}
// Upcasting the input tensor to `F64` dtype for higher precision during the
// computation of the result.
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
inputTensorTy = self.getType().cast<BaseTensorType>();
}
unsigned inputRank = getTensorRank(self);
SmallVector<Value> dimListElements;
bool isNoneOrEmpty = true;
if (!dimList.getType().template isa<Torch::NoneType>()) {
if (!getListConstructElements(dimList, dimListElements))
return rewriter.notifyMatchFailure(
op, "expect dimList to be constructed from list construct");
if (!dimListElements.empty() || inputRank == 0)
isNoneOrEmpty = false;
}
if (isNoneOrEmpty) {
for (unsigned i = 0; i < inputRank; i++)
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
dimListElements);
}
Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
dimListElements[i],
/*keepDim=*/true);
Value constantNone = rewriter.create<ConstantNoneOp>(loc);
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
/*dtype=*/constantNone);
Value subMean =
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
if (!unbiased) {
Value result = rewriter.create<AtenMeanDimOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
result = convertTensorToDtype(rewriter, loc, result,
outputTensorType.getDtype());
rewriter.replaceOp(op, result);
return success();
}
// Divide the square sum by productDimSize - correction.
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
// `productDimSize` is product of sizes of dimensions to be reduced.
Value constantOne =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value productDimSize = constantOne;
for (Value dim : dimListElements) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(correction));
// The `correction` value should be less than or equal to `productDimSize +
// 1`.
Value productDimSizePlusOne =
rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
Value cond =
rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"correction value should be less than or equal to productDimSize + 1");
Value productDimSizeSubCorrection =
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
productDimSizeSubCorrection);
result =
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
rewriter.replaceOp(op, result);
return success();
}
// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
@ -2151,70 +2255,44 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarDimOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value dimList = op.dim();
Value keepDim = op.keepdim();
Type outputType = op.getType();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"support floating type input only");
}
auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
if (!dimListConstruct) {
return rewriter.notifyMatchFailure(
op, "expect dimList to be constructed from list construct");
}
bool unbiased;
if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var");
}
int64_t correction = unbiased ? 1 : 0;
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success();
}
};
} // namespace
SmallVector<Value> dimListElements = dimListConstruct.elements();
Type meanDimResultType = inputTensorTy;
for (unsigned i = 0; i < dimListElements.size(); i++)
meanDimResultType = computeReductionType(
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
dimListElements[i],
/*keepDim=*/true);
Value constantNone = rewriter.create<ConstantNoneOp>(loc);
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
/*dtype=*/constantNone);
Value subMean =
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
if (unbiased) {
// Bessels correction is used. Divide the square sum by
// productDimSize-1.
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone);
// `productDimSize` is product of sizes of dimensions to be reduced.
Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (Value dim : dimListConstruct.elements()) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value productDimSizeSubOne =
rewriter.create<AtenSubIntOp>(loc, productDimSize, constantOne);
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, squareSum,
productDimSizeSubOne);
// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
// out = aten.sum(square, dims) / (productDimSize - correction)
namespace {
class DecomposeAtenVarCorrectionOp
: public OpRewritePattern<AtenVarCorrectionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
PatternRewriter &rewriter) const override {
int64_t correction;
if (!op.correction().getType().isa<Torch::NoneType>()) {
if (!matchPattern(op.correction(), m_TorchConstantInt(&correction)))
return rewriter.notifyMatchFailure(
op, "Only support constant int correction for aten.var");
} else {
rewriter.replaceOpWithNewOp<AtenMeanDimOp>(
op, outputType, square, dimList, keepDim, /*dtype=*/constantNone);
// The default value in case of `correction` being None is 1.
correction = 1;
}
bool unbiased = correction == 0 ? false : true;
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success();
}
};
@ -2426,6 +2504,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenSelectScatterOp>();
patterns.add<DecomposeAtenVarDimOp>(context);
target.addIllegalOp<AtenVarDimOp>();
patterns.add<DecomposeAtenVarCorrectionOp>(context);
target.addIllegalOp<AtenVarCorrectionOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -931,7 +931,7 @@ void TypeAnalysis::visitOperation(Operation *op,
Type dtype = operands[0]->getValue().dtype;
visitReductionAlongAllDimsOp(max, dtype, operands);
return;
} else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp>(op)) {
} else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp, AtenVarCorrectionOp>(op)) {
auto input = operands[0]->getValue();
visitReductionAlongAllDimsOp(op, input.dtype, operands);
return;

View File

@ -5570,6 +5570,36 @@ module {
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>

View File

@ -492,6 +492,11 @@ def atenvar(self: List[int], unbiased: bool = True) -> List[int]:
def atenvardim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
def atenvarcorrection(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
return []

View File

@ -383,6 +383,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")

View File

@ -427,3 +427,160 @@ class VarDimKeepDimFalseModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule())
def VarDimKeepDimFalseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class VarCorrectionModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, correction=2)
@register_test_case(module_factory=lambda: VarCorrectionModule())
def VarCorrectionModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionSingleDimReduceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[1], correction=1)
@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule())
def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionAllDimReduceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x,
dim=[0, 1, 2],
correction=10,
keepdim=False)
@register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule())
def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionKeepDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
@register_test_case(module_factory=lambda: VarCorrectionKeepDimModule())
def VarCorrectionKeepDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionNoneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, correction=None)
@register_test_case(module_factory=lambda: VarCorrectionNoneModule())
def VarCorrectionNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[], correction=2)
@register_test_case(module_factory=lambda: VarCorrectionEmptyDimModule())
def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarCorrectionLargeInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[2, 3], correction=2)
@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 1024, 8192))

View File

@ -233,28 +233,36 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST1_2:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
%0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
@ -269,26 +277,34 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
%0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
@ -303,28 +319,36 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST1_2:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
@ -340,26 +364,34 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
@ -1165,22 +1197,30 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,1],f32>, !torch.float -> !torch.vtensor<[3,4,7],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,7],f32> -> !torch.vtensor<[3,4,7],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[VAR]] : !torch.vtensor<[3,4,1],f32>
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
%int2 = torch.constant.int 2
%dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
@ -1208,3 +1248,46 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
%ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
return %ret : !torch.tensor<[2,3],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.var.correction(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
// CHECK: %[[NUM_ELEMENTS_PLUS_ONE:.*]] = torch.aten.add.int %[[NUM_ELEMENTS_0]], %[[CST1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[PRED:.*]] = torch.aten.ge.int %[[NUM_ELEMENTS_PLUS_ONE]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED]], "correction value should be less than or equal to productDimSize + 1"
// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
%int2 = torch.constant.int 2
%dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%keepdim = torch.constant.bool true
%0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}