diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2115d3aec..2623aa4fa 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1366,11 +1366,18 @@ public: Value input = op.self(); Value output = op.result(); BaseTensorType outputTensorType = output.getType().cast(); - Value sum = - rewriter.create(loc, outputTensorType, input, op.dtype()); + Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype( + outputTensorType.getSizes(), rewriter.getF64Type()); + Value sum = rewriter.create( + loc, outputTensorTypeAsF64, input, + rewriter.create( + loc, rewriter.getI64IntegerAttr( + (int)getScalarTypeForType(rewriter.getF64Type())))); Value numTensorElements = rewriter.create(loc, input); - rewriter.replaceOpWithNewOp(op, outputTensorType, sum, - numTensorElements); + Value mean = rewriter.create(loc, outputTensorTypeAsF64, + sum, numTensorElements); + rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, mean, + outputTensorType.getDtype())); return success(); } }; @@ -1390,7 +1397,10 @@ public: Value dimList = op.dim(); Value keepDim = op.keepdim(); Value dtype = op.dtype(); - Type outputType = op.getType(); + BaseTensorType outputTensorType = + op.result().getType().cast(); + Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype( + outputTensorType.getSizes(), rewriter.getF64Type()); MLIRContext *context = op.getContext(); BaseTensorType inputType = input.getType().cast(); @@ -1409,7 +1419,7 @@ public: // Compute sum along dimensions specified in `dimList`. Value sumAlongDims = rewriter.create( - loc, outputType, input, dimList, keepDim, dtype); + loc, outputTensorTypeAsF64, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. Value productDimSize; @@ -1425,8 +1435,11 @@ public: rewriter.create(loc, productDimSize, dimSize); } } - rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, - productDimSize); + Value meanDim = rewriter.create( + loc, outputTensorTypeAsF64, sumAlongDims, productDimSize); + rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, meanDim, + outputTensorType.getDtype())); + return success(); } }; @@ -2752,7 +2765,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, BaseTensorType inputTensorTy = self.getType().cast(); Type outputType = op.getType(); BaseTensorType outputTensorType = outputType.cast(); - Type newOutputType = outputTensorType.getWithSizesAndDtype( + Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype( outputTensorType.getSizes(), rewriter.getF64Type()); if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) { @@ -2802,8 +2815,9 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Value square = rewriter.create(loc, inputTensorTy, subMean); if (!unbiased) { - Value result = rewriter.create( - loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + Value result = rewriter.create(loc, outputTensorTypeAsF64, + square, dimList, keepDim, + /*dtype=*/constantNone); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); @@ -2811,7 +2825,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, } // Divide the square sum by productDimSize - correction. Value squareSum = rewriter.create( - loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + loc, outputTensorTypeAsF64, square, dimList, keepDim, + /*dtype=*/constantNone); // `productDimSize` is product of sizes of dimensions to be reduced. Value constantOne = @@ -2835,8 +2850,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, "correction value should be less than or equal to productDimSize + 1"); Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); - Value result = rewriter.create(loc, newOutputType, squareSum, - productDimSizeSubCorrection); + Value result = rewriter.create( + loc, outputTensorTypeAsF64, squareSum, productDimSizeSubCorrection); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py index 3646b30c0..73541eada 100644 --- a/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/python/torch_mlir_e2e_test/test_suite/stats.py @@ -68,6 +68,25 @@ def MeanDtypeModule_basic(module, tu: TestUtils): # ============================================================================== +class MeanLargeInputModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.mean(x) + + +@register_test_case(module_factory=lambda: MeanLargeInputModule()) +def MeanLargeInputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200)) + +# ============================================================================== + class MeanDimModule(torch.nn.Module): def __init__(self): super().__init__() @@ -87,6 +106,26 @@ def MeanDimModule_basic(module, tu: TestUtils): # ============================================================================== +class MeanDimLargeInputModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.mean(x, (0, 2)) + + +@register_test_case(module_factory=lambda: MeanDimLargeInputModule()) +def MeanDimLargeInputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200)) + + +# ============================================================================== + class MeanDimDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @@ -531,7 +570,7 @@ class VarDimAllDimReduceModule(torch.nn.Module): @register_test_case(module_factory=lambda: VarDimAllDimReduceModule()) def VarDimAllDimReduceModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(tu.rand(3, 128, 1024, low=100, high=200)) # ============================================================================== @@ -754,7 +793,7 @@ class VarCorrectionLargeInputModule(torch.nn.Module): @register_test_case(module_factory=lambda: VarCorrectionLargeInputModule()) def VarCorrectionLargeInputModule_basic(module, tu: TestUtils): - module.forward(100 + tu.rand(3, 4, 1024, 8192)) + module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200)) # ============================================================================== @@ -768,7 +807,7 @@ class VarMeanCorrectionModule(torch.nn.Module): @export @annotate_args([ None, - ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), ]) def forward(self, x): return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True) @@ -776,7 +815,7 @@ class VarMeanCorrectionModule(torch.nn.Module): @register_test_case(module_factory=lambda: VarMeanCorrectionModule()) def VarMeanCorrectionModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 7)) + module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200)) # ============================================================================== diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index e0d9d85c0..9002894c1 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -248,8 +248,12 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 @@ -261,9 +265,9 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_0:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> // CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true @@ -292,8 +296,12 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 @@ -302,10 +310,14 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_0:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> // CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false @@ -334,8 +346,12 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 @@ -347,9 +363,9 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_0:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { @@ -379,8 +395,12 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 @@ -389,10 +409,14 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_0:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { @@ -819,18 +843,26 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t // CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 // CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_1:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_3:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> // CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { %int2 = torch.constant.int 2 @@ -877,8 +909,12 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) - // CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 @@ -891,9 +927,9 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) - // CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_1:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> // CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { %int2 = torch.constant.int 2 @@ -921,18 +957,26 @@ func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !tor // CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST7_1:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64> // CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64> // CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 // CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST7_2:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false +// CHECK: %[[NONE_2:.*]] = torch.constant.none +// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false -// CHECK: %[[NONE_1:.*]] = torch.constant.none -// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false +// CHECK: %[[NONE_3:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> // CHECK: %[[STD:.*]] = torch.aten.sqrt %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> -> !torch.vtensor<[3,4,1],f32> // CHECK: return %[[STD]] : !torch.vtensor<[3,4,1],f32> func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> { @@ -1019,10 +1063,14 @@ func.func @torch.aten.mse_loss$no_reduction(%arg0: !torch.vtensor<[?,?],f32>, %a // CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32> +// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[NUMEL:.*]] = torch.aten.numel %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> -> !torch.int -// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[SUB_SQUARE_MEAN]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.vtensor<[?,?],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[SUB_SQUARE_MEAN_CAST:.*]] = torch.aten.to.dtype %[[SUB_SQUARE_MEAN]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[SUB_SQUARE_MEAN_CAST]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.mse_loss$mean_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %int1 = torch.constant.int 1 %0 = torch.aten.mse_loss %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>