mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix mean and mean.dim op for large-sized inputs
This commit fixes the aten.mean and aten.mean.dim op decomposition for supporting large-sized inputs. This commit also fixes the formatting for the file stats.py Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1597/head
parent
ed901094c1
commit
55c7e66aa7
|
@ -1366,11 +1366,18 @@ public:
|
|||
Value input = op.self();
|
||||
Value output = op.result();
|
||||
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
||||
Value sum =
|
||||
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
|
||||
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
|
||||
outputTensorType.getSizes(), rewriter.getF64Type());
|
||||
Value sum = rewriter.create<AtenSumOp>(
|
||||
loc, outputTensorTypeAsF64, input,
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(
|
||||
(int)getScalarTypeForType(rewriter.getF64Type()))));
|
||||
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
||||
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
|
||||
numTensorElements);
|
||||
Value mean = rewriter.create<AtenDivScalarOp>(loc, outputTensorTypeAsF64,
|
||||
sum, numTensorElements);
|
||||
rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, mean,
|
||||
outputTensorType.getDtype()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1390,7 +1397,10 @@ public:
|
|||
Value dimList = op.dim();
|
||||
Value keepDim = op.keepdim();
|
||||
Value dtype = op.dtype();
|
||||
Type outputType = op.getType();
|
||||
BaseTensorType outputTensorType =
|
||||
op.result().getType().cast<BaseTensorType>();
|
||||
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
|
||||
outputTensorType.getSizes(), rewriter.getF64Type());
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
|
@ -1409,7 +1419,7 @@ public:
|
|||
|
||||
// Compute sum along dimensions specified in `dimList`.
|
||||
Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
|
||||
loc, outputType, input, dimList, keepDim, dtype);
|
||||
loc, outputTensorTypeAsF64, input, dimList, keepDim, dtype);
|
||||
|
||||
// `productDimSize` is product of sizes of dimensions to be reduced.
|
||||
Value productDimSize;
|
||||
|
@ -1425,8 +1435,11 @@ public:
|
|||
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
|
||||
productDimSize);
|
||||
Value meanDim = rewriter.create<AtenDivScalarOp>(
|
||||
loc, outputTensorTypeAsF64, sumAlongDims, productDimSize);
|
||||
rewriter.replaceOp(op, convertTensorToDtype(rewriter, loc, meanDim,
|
||||
outputTensorType.getDtype()));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2752,7 +2765,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||||
Type outputType = op.getType();
|
||||
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
|
||||
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
||||
Type outputTensorTypeAsF64 = outputTensorType.getWithSizesAndDtype(
|
||||
outputTensorType.getSizes(), rewriter.getF64Type());
|
||||
if (!inputTensorTy.hasDtype() ||
|
||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||||
|
@ -2802,8 +2815,9 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
|
||||
|
||||
if (!unbiased) {
|
||||
Value result = rewriter.create<AtenMeanDimOp>(
|
||||
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
||||
Value result = rewriter.create<AtenMeanDimOp>(loc, outputTensorTypeAsF64,
|
||||
square, dimList, keepDim,
|
||||
/*dtype=*/constantNone);
|
||||
result = convertTensorToDtype(rewriter, loc, result,
|
||||
outputTensorType.getDtype());
|
||||
rewriter.replaceOp(op, result);
|
||||
|
@ -2811,7 +2825,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
}
|
||||
// Divide the square sum by productDimSize - correction.
|
||||
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
|
||||
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
||||
loc, outputTensorTypeAsF64, square, dimList, keepDim,
|
||||
/*dtype=*/constantNone);
|
||||
|
||||
// `productDimSize` is product of sizes of dimensions to be reduced.
|
||||
Value constantOne =
|
||||
|
@ -2835,8 +2850,8 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
"correction value should be less than or equal to productDimSize + 1");
|
||||
Value productDimSizeSubCorrection =
|
||||
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
|
||||
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
|
||||
productDimSizeSubCorrection);
|
||||
Value result = rewriter.create<AtenDivScalarOp>(
|
||||
loc, outputTensorTypeAsF64, squareSum, productDimSizeSubCorrection);
|
||||
result =
|
||||
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
|
||||
rewriter.replaceOp(op, result);
|
||||
|
|
|
@ -68,6 +68,25 @@ def MeanDtypeModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanLargeInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanLargeInputModule())
|
||||
def MeanLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -87,6 +106,26 @@ def MeanDimModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimLargeInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimLargeInputModule())
|
||||
def MeanDimLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -531,7 +570,7 @@ class VarDimAllDimReduceModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: VarDimAllDimReduceModule())
|
||||
def VarDimAllDimReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 128, 1024, low=100, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -754,7 +793,7 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
|
||||
def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(100 + tu.rand(3, 4, 1024, 8192))
|
||||
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -768,7 +807,7 @@ class VarMeanCorrectionModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True)
|
||||
|
@ -776,7 +815,7 @@ class VarMeanCorrectionModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: VarMeanCorrectionModule())
|
||||
def VarMeanCorrectionModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -248,8 +248,12 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
|
|||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
|
||||
|
@ -261,9 +265,9 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
|
|||
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
%true = torch.constant.bool true
|
||||
|
@ -292,8 +296,12 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
|
||||
|
@ -302,10 +310,14 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
%false = torch.constant.bool false
|
||||
|
@ -334,8 +346,12 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
|
|||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
|
||||
|
@ -347,9 +363,9 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
|
|||
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
|
@ -379,8 +395,12 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
|
||||
|
@ -389,10 +409,14 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[BIASED_VAR_CAST:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_2]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
|
@ -819,18 +843,26 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t
|
|||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_3:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
|
||||
func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -877,8 +909,12 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
|
|||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
|
||||
|
@ -891,9 +927,9 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
|
|||
// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
|
||||
func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -921,18 +957,26 @@ func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !tor
|
|||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST7_1:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEAN_CAST:.*]] = torch.aten.to.dtype %[[MEAN]], %[[CST7_1]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64>
|
||||
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN_CAST]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64>
|
||||
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST7_2:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[FALSE_1:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAR_CAST:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST7_2]], %[[FALSE_1]], %[[FALSE_1]], %[[NONE_2]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_3:.*]] = torch.constant.none
|
||||
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR_CAST]], %[[CST6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: %[[STD:.*]] = torch.aten.sqrt %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> -> !torch.vtensor<[3,4,1],f32>
|
||||
// CHECK: return %[[STD]] : !torch.vtensor<[3,4,1],f32>
|
||||
func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> {
|
||||
|
@ -1019,10 +1063,14 @@ func.func @torch.aten.mse_loss$no_reduction(%arg0: !torch.vtensor<[?,?],f32>, %a
|
|||
// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[NUMEL:.*]] = torch.aten.numel %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> -> !torch.int
|
||||
// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[SUB_SQUARE_MEAN]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.vtensor<[?,?],f64>
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[SUB_SQUARE_MEAN_CAST:.*]] = torch.aten.to.dtype %[[SUB_SQUARE_MEAN]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[SUB_SQUARE_MEAN_CAST]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.mse_loss$mean_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.mse_loss %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
|
|
Loading…
Reference in New Issue