[MLIR][TORCH] Fix mean and mean.dim op for large-sized inputs

This commit fixes the decomposition of the aten.mean and aten.mean.dim ops
so that they support large-sized inputs: the reductions are now accumulated
in f64 and the result is converted back to the output dtype, which avoids
precision loss in the intermediate sum.
This commit also fixes the formatting of the file stats.py.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/1597/head
Vivek Khandelwal 2022-11-17 18:10:02 +05:30
parent ed901094c1
commit 55c7e66aa7
3 changed files with 151 additions and 49 deletions
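
For context, here is a minimal eager-mode PyTorch sketch (not part of the change) of the precision issue that the f64 accumulation in this commit addresses. The shape and value range come from the new e2e tests; the cumsum below is only a stand-in for a naive sequential f32 summation, not the actual lowering:

import torch

# Shape and value range used by the new MeanLargeInputModule_basic test.
x = torch.empty(3, 4, 128, 1024).uniform_(100, 200).flatten()

# Sequential f32 accumulation: every partial sum is rounded to f32, so
# rounding error builds up once the running sum (~2.4e8) dwarfs the elements.
mean_f32_accum = (x.cumsum(0)[-1] / x.numel()).item()

# What the updated decomposition does instead: accumulate in f64, divide by
# numel, then convert the result back to the original output dtype.
mean_f64_accum = (x.to(torch.float64).sum() / x.numel()).to(torch.float32).item()

print(mean_f32_accum, mean_f64_accum)  # both near 150; the f32 accumulation typically drifts more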


@@ -1366,11 +1366,18 @@ public:
Value input = op.self();
Value output = op.result();
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
Value sum =
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
Value sum = rewriter.create<AtenSumOp>(
loc, outputTensorTypeAsF64, input,
rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(
(int)getScalarTypeForType(rewriter.getF64Type()))));
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
numTensorElements);
Value mean = rewriter.create<AtenDivScalarOp>(loc, outputTensorTypeAsF64,
sum, numTensorElements);
rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, mean,
outputTensorType.getDtype()));
return success();
}
};
@@ -1390,7 +1397,10 @@ public:
Value dimList = op.dim();
Value keepDim = op.keepdim();
Value dtype = op.dtype();
Type outputType = op.getType();
BaseTensorType outputTensorType =
op.result().getType().cast<BaseTensorType>();
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
MLIRContext *context = op.getContext();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
@@ -1409,7 +1419,7 @@ public:
// Compute sum along dimensions specified in `dimList`.
Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
loc, outputType, input, dimList, keepDim, dtype);
loc, outputTensorTypeAsF64, input, dimList, keepDim, dtype);
// `productDimSize` is product of sizes of dimensions to be reduced.
Value productDimSize;
@@ -1425,8 +1435,11 @@ public:
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
}
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
productDimSize);
Value meanDim = rewriter.create<AtenDivScalarOp>(
loc, outputTensorTypeAsF64, sumAlongDims, productDimSize);
rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, meanDim,
outputTensorType.getDtype()));
return success();
}
};
@@ -2752,7 +2765,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
Type newOutputType = outputTensorType.getWithSizesAndDtype(
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
@@ -2802,8 +2815,9 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
if (!unbiased) {
Value result = rewriter.create<AtenMeanDimOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
Value result = rewriter.create<AtenMeanDimOp>(loc, outputTensorTypeAsF64,
square, dimList, keepDim,
/*dtype=*/constantNone);
result = convertTensorToDtype(rewriter, loc, result,
outputTensorType.getDtype());
rewriter.replaceOp(op, result);
@@ -2811,7 +2825,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
}
// Divide the square sum by productDimSize - correction.
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
loc, outputTensorTypeAsF64, square, dimList, keepDim,
/*dtype=*/constantNone);
// `productDimSize` is product of sizes of dimensions to be reduced.
Value constantOne =
@@ -2835,8 +2850,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
"correction value should be less than or equal to productDimSize + 1");
Value productDimSizeSubCorrection =
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
productDimSizeSubCorrection);
Value result = rewriter.create<AtenDivScalarOp>(
loc, outputTensorTypeAsF64, squareSum, productDimSizeSubCorrection);
result =
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
rewriter.replaceOp(op, result);
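
The pattern above is the same for aten.mean, aten.mean.dim, and the variance helper: perform the reduction and the scalar division with an f64 result type, and only convert back to the op's result dtype at the end. A rough eager-mode Python analogue of the new aten.mean.dim decomposition (illustrative only; decomposed_mean_dim is not an actual torch-mlir helper):

import torch

def decomposed_mean_dim(x: torch.Tensor, dims, keepdim: bool) -> torch.Tensor:
    # Sum along `dims`, accumulating in f64 (AtenSumDimIntListOp with the
    # f64 result type in the decomposition).
    summed = x.sum(dim=dims, keepdim=keepdim, dtype=torch.float64)
    # `productDimSize`: product of the sizes of the reduced dimensions.
    count = 1
    for d in dims:
        count *= x.shape[d]
    # Divide in f64 (AtenDivScalarOp), then cast back to the result dtype
    # (convertTensorToDtype in the decomposition).
    return (summed / count).to(x.dtype)

x = torch.empty(3, 4, 128, 1024).uniform_(100, 200)
print(torch.allclose(torch.ops.aten.mean(x, (0, 2)),
                     decomposed_mean_dim(x, (0, 2), keepdim=False)))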


@@ -68,6 +68,25 @@ def MeanDtypeModule_basic(module, tu: TestUtils):
# ==============================================================================
class MeanLargeInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x)
@register_test_case(module_factory=lambda: MeanLargeInputModule())
def MeanLargeInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
# ==============================================================================
class MeanDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -87,6 +106,26 @@ def MeanDimModule_basic(module, tu: TestUtils):
# ==============================================================================
class MeanDimLargeInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, 2))
@register_test_case(module_factory=lambda: MeanDimLargeInputModule())
def MeanDimLargeInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
# ==============================================================================
class MeanDimDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -531,7 +570,7 @@ class VarDimAllDimReduceModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarDimAllDimReduceModule())
def VarDimAllDimReduceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
module.forward(tu.rand(3, 128, 1024, low=100, high=200))
# ==============================================================================
@@ -754,7 +793,7 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
module.forward(100 + tu.rand(3, 4, 1024, 8192))
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
# ==============================================================================
@@ -768,7 +807,7 @@ class VarMeanCorrectionModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True)
@@ -776,7 +815,7 @@ class VarMeanCorrectionModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarMeanCorrectionModule())
def VarMeanCorrectionModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
# ==============================================================================
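
The new large-input tests draw values uniformly from [100, 200), so every reduced mean should land near 150. A quick eager-mode sanity check of what these tests exercise (assuming tu.rand(..., low=100, high=200) samples uniformly from [low, high), which is how it is used here):

import torch

x = torch.empty(3, 4, 128, 1024).uniform_(100, 200)  # same distribution as tu.rand(..., low=100, high=200)
print(torch.ops.aten.mean(x))                # MeanLargeInputModule: scalar near 150
print(torch.ops.aten.mean(x, (0, 2)).shape)  # MeanDimLargeInputModule: torch.Size([4, 1024])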


@@ -248,8 +248,12 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
@@ -261,9 +265,9 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
@@ -292,8 +296,12 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
@@ -302,10 +310,14 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
@@ -334,8 +346,12 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
@@ -347,9 +363,9 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
@@ -379,8 +395,12 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
@@ -389,10 +409,14 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
@@ -819,18 +843,26 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
%int2 = torch.constant.int 2
@@ -877,8 +909,12 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
@@ -891,9 +927,9 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
%int2 = torch.constant.int 2
@@ -921,18 +957,26 @@ func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !tor
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
// CHECK: %[[NONE_2:.*]] = torch.constant.none
// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
// CHECK: %[[NONE_3:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[STD:.*]] = torch.aten.sqrt %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[STD]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> {
@@ -1019,10 +1063,14 @@ func.func @torch.aten.mse_loss$no_reduction(%arg0: !torch.vtensor<[?,?],f32>, %a
// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[NUMEL:.*]] = torch.aten.numel %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> -> !torch.int
// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[SUB_SQUARE_MEAN]] : !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.vtensor<[?,?],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[SUB_SQUARE_MEAN_CAST:.*]] = torch.aten.to.dtype %[[SUB_SQUARE_MEAN]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[SUB_SQUARE_MEAN_CAST]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.mse_loss$mean_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.mse_loss %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>