mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add decomposition for aten.var.correction op
This commit adds the decomposition for the `aten.var.correction` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1122/head
parent 7247c6a3a7
commit d386b8f9e5
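For context, here is a rough Python sketch of the arithmetic this decomposition performs (an illustration only, not the actual lowering; the `var_correction_reference` helper is hypothetical): `aten.var.correction` reduces to a mean over the reduction dims, a squared difference, a sum, and a division by `N - correction`.

import torch

def var_correction_reference(x, dim=None, correction=1, keepdim=False):
    # Hypothetical reference helper mirroring the decomposition in this commit:
    # mean -> sub -> square -> sum, then divide by (reduced element count - correction).
    if dim is None or len(dim) == 0:
        dim = list(range(x.dim()))
    n = 1
    for d in dim:
        n *= x.shape[d]
    mean = x.mean(dim=dim, keepdim=True)
    square = (x - mean) ** 2
    return square.sum(dim=dim, keepdim=keepdim) / (n - correction)

x = torch.rand(3, 4, 7)
expected = torch.ops.aten.var(x, dim=[2], correction=2)  # the op exercised by the new e2e tests
# Should agree with the reference up to floating-point tolerance.
assert torch.allclose(var_correction_reference(x, dim=[2], correction=2), expected)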
@@ -4047,6 +4047,32 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
  }];
}

def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalListOfTorchIntType:$dim,
    AnyTorchOptionalIntType:$correction,
    Torch_BoolType:$keepdim
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenVarCorrectionOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -229,8 +229,7 @@ public:
} // namespace

namespace {
class DecomposeAtenZeroOp
    : public OpRewritePattern<AtenZeroOp> {
class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenZeroOp op,
@@ -705,7 +704,8 @@ public:
// unsqueezed_sizes += [1, s]
// expanded_sizes += [m, s]
// reshaped_sizes += [m * s]
// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
// return
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
//
namespace {
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
@@ -754,7 +754,8 @@ public:
    assert(leadingRank >= 0 && "leadingRank should greater than 0");
    for (size_t i = 0; i < leadingRank; ++i) {
      insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
      insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{repeats[i]});
      insertDimSizes(expandedSizes, expandedIntSizes,
                     ArrayRef<Value>{repeats[i]});
      reshapedSizes.push_back(repeats[i]);
    }

@@ -772,18 +773,20 @@ public:
          loc, rewriter.getI64IntegerAttr(selfShape[i]));
    }

    insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one, dimSize});
    insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef<Value>{scale, dimSize});
    insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
                   ArrayRef<Value>{one, dimSize});
    insertDimSizes(expandedSizes, expandedIntSizes,
                   ArrayRef<Value>{scale, dimSize});

    Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
    reshapedSizes.push_back(scaledSize);
  }

  Type dtype = self.getType().cast<ValueTensorType>().getDtype();
  Type unsqueezedType =
      ValueTensorType::get(context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
  Type expandedType =
      ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype);
  Type unsqueezedType = ValueTensorType::get(
      context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
  Type expandedType = ValueTensorType::get(
      context, llvm::makeArrayRef(expandedIntSizes), dtype);

  auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
  Value unsqueezedDims =
@@ -792,8 +795,8 @@ public:
        rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
    Value reshapedDims =
        rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
    auto reshaped =
        rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(), unsqueezedDims);
    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.self(),
                                                unsqueezedDims);
    auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
                                                       reshaped, expandedDims);

@@ -1184,14 +1187,14 @@ public:
    Value input = op.self();
    BaseTensorType inputType = input.getType().cast<BaseTensorType>();

    Value inputTimesBeta = rewriter.create<AtenMulScalarOp>(
        loc, inputType, input, op.beta());
    Value inputTimesBeta =
        rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.beta());

    // out = log1p(exp(input * beta)) / beta
    Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
    Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
    Value out = rewriter.create<AtenDivScalarOp>(
        loc, inputType, log1p, op.beta());
    Value out =
        rewriter.create<AtenDivScalarOp>(loc, inputType, log1p, op.beta());

    // Select where x * beta > threshold
    auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
@@ -1199,8 +1202,8 @@ public:
    Value condition = rewriter.create<AtenGtScalarOp>(
        loc, boolResType, inputTimesBeta, op.threshold());

    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
        op, op.getType(), condition, input, out);
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), condition,
                                                 input, out);
    return success();
  }
};
@@ -2138,6 +2141,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
};
} // namespace

template <typename OpTy>
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
                                       bool unbiased, int64_t correction) {
  Location loc = op.getLoc();
  Value self = op.self();
  Value dimList = op.dim();
  Value keepDim = op.keepdim();
  BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
  Type outputType = op.getType();
  BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
  Type newOutputType = outputTensorType.getWithSizesAndDtype(
      outputTensorType.getSizes(), rewriter.getF64Type());
  if (!inputTensorTy.hasDtype() ||
      !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
    return rewriter.notifyMatchFailure(
        op, "support floating-point type input only");
  }

  // Upcasting the input tensor to `F64` dtype for higher precision during the
  // computation of the result.
  if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
    self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
    inputTensorTy = self.getType().cast<BaseTensorType>();
  }

  unsigned inputRank = getTensorRank(self);
  SmallVector<Value> dimListElements;
  bool isNoneOrEmpty = true;
  if (!dimList.getType().template isa<Torch::NoneType>()) {
    if (!getListConstructElements(dimList, dimListElements))
      return rewriter.notifyMatchFailure(
          op, "expect dimList to be constructed from list construct");
    if (!dimListElements.empty() || inputRank == 0)
      isNoneOrEmpty = false;
  }
  if (isNoneOrEmpty) {
    for (unsigned i = 0; i < inputRank; i++)
      dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(i)));
    dimList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
        dimListElements);
  }
  Type meanDimResultType = inputTensorTy;
  for (unsigned i = 0; i < dimListElements.size(); i++)
    meanDimResultType = computeReductionType(
        rewriter, op, meanDimResultType.cast<BaseTensorType>(),
        dimListElements[i],
        /*keepDim=*/true);

  Value constantNone = rewriter.create<ConstantNoneOp>(loc);
  Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
  Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
      loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
      /*dtype=*/constantNone);
  Value subMean =
      createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
  Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);

  if (!unbiased) {
    Value result = rewriter.create<AtenMeanDimOp>(
        loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
    result = convertTensorToDtype(rewriter, loc, result,
                                  outputTensorType.getDtype());
    rewriter.replaceOp(op, result);
    return success();
  }
  // Divide the square sum by productDimSize - correction.
  Value squareSum = rewriter.create<AtenSumDimIntListOp>(
      loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);

  // `productDimSize` is product of sizes of dimensions to be reduced.
  Value constantOne =
      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
  Value productDimSize = constantOne;
  for (Value dim : dimListElements) {
    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
    productDimSize =
        rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
  }
  Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(correction));
  // The `correction` value should be less than or equal to `productDimSize +
  // 1`.
  Value productDimSizePlusOne =
      rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
  Value cond =
      rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
  rewriter.create<RuntimeAssertOp>(
      loc, cond,
      "correction value should be less than or equal to productDimSize + 1");
  Value productDimSizeSubCorrection =
      rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
  Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
                                                  productDimSizeSubCorrection);
  result =
      convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
  rewriter.replaceOp(op, result);
  return success();
}

// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
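As a quick sanity check of the scalar bookkeeping in the helper above (a sketch only; `reduced_element_count` is a hypothetical stand-in for the `productDimSize` loop, not code from this patch):

def reduced_element_count(shape, dims):
    # Product of the sizes of the reduced dimensions, as the loop over
    # dimListElements builds it one AtenMulIntOp at a time.
    n = 1
    for d in dims:
        n *= shape[d]
    return n

shape, dims, correction = (3, 4, 7), [2], 2
n = reduced_element_count(shape, dims)      # 7
assert n + 1 >= correction                  # mirrors the emitted torch.runtime.assert
denominator = n - correction                # divisor applied to the squared-deviation sum (5 here)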
@@ -2151,70 +2255,44 @@ public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenVarDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Value dimList = op.dim();
    Value keepDim = op.keepdim();
    Type outputType = op.getType();
    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
    if (!inputTensorTy.hasDtype() ||
        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(op,
                                         "support floating type input only");
    }

    auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
    if (!dimListConstruct) {
      return rewriter.notifyMatchFailure(
          op, "expect dimList to be constructed from list construct");
    }

    bool unbiased;
    if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
      return rewriter.notifyMatchFailure(
          op, "Only support constant unbiased for aten.var");
    }
    int64_t correction = unbiased ? 1 : 0;
    if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
                                               correction)))
      return rewriter.notifyMatchFailure(op, "invalid variance parameters");
    return success();
  }
};
} // namespace

    SmallVector<Value> dimListElements = dimListConstruct.elements();
    Type meanDimResultType = inputTensorTy;
    for (unsigned i = 0; i < dimListElements.size(); i++)
      meanDimResultType = computeReductionType(
          rewriter, op, meanDimResultType.cast<BaseTensorType>(),
          dimListElements[i],
          /*keepDim=*/true);

    Value constantNone = rewriter.create<ConstantNoneOp>(loc);
    Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
    Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
        loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
        /*dtype=*/constantNone);
    Value subMean =
        createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
    Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
    if (unbiased) {
      // Bessel's correction is used. Divide the square sum by
      // productDimSize-1.
      Value squareSum = rewriter.create<AtenSumDimIntListOp>(
          loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone);

      // `productDimSize` is product of sizes of dimensions to be reduced.
      Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      for (Value dim : dimListConstruct.elements()) {
        Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
        productDimSize =
            rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
      }
      Value constantOne = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value productDimSizeSubOne =
          rewriter.create<AtenSubIntOp>(loc, productDimSize, constantOne);
      rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, squareSum,
                                                   productDimSizeSubOne);
// Decompose aten.var(x, dims) into:
// sub = aten.sub(x, aten.mean(x, dims))
// square = aten.square(sub)
// out = aten.sum(square, dims) / (productDimSize - correction)
namespace {
class DecomposeAtenVarCorrectionOp
    : public OpRewritePattern<AtenVarCorrectionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
                                PatternRewriter &rewriter) const override {
    int64_t correction;
    if (!op.correction().getType().isa<Torch::NoneType>()) {
      if (!matchPattern(op.correction(), m_TorchConstantInt(&correction)))
        return rewriter.notifyMatchFailure(
            op, "Only support constant int correction for aten.var");
    } else {
      rewriter.replaceOpWithNewOp<AtenMeanDimOp>(
          op, outputType, square, dimList, keepDim, /*dtype=*/constantNone);
      // The default value in case of `correction` being None is 1.
      correction = 1;
    }
    bool unbiased = correction == 0 ? false : true;
    if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
                                                      correction)))
      return rewriter.notifyMatchFailure(op, "invalid variance parameters");
    return success();
  }
};
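One behavioural detail worth noting from the pattern above: a `None` correction is treated as `1` (Bessel's correction). A hedged eager-mode check of that equivalence, assuming a PyTorch build that exposes `aten.var.correction` as in the e2e tests added later in this patch:

import torch

x = torch.rand(3, 4, 7)
# correction=None defaults to 1 inside DecomposeAtenVarCorrectionOp, matching eager PyTorch.
a = torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
b = torch.ops.aten.var(x, dim=[0, 1], correction=1, keepdim=True)
assert torch.allclose(a, b)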
@@ -2426,6 +2504,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenSelectScatterOp>();
    patterns.add<DecomposeAtenVarDimOp>(context);
    target.addIllegalOp<AtenVarDimOp>();
    patterns.add<DecomposeAtenVarCorrectionOp>(context);
    target.addIllegalOp<AtenVarCorrectionOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
@@ -931,7 +931,7 @@ void TypeAnalysis::visitOperation(Operation *op,
    Type dtype = operands[0]->getValue().dtype;
    visitReductionAlongAllDimsOp(max, dtype, operands);
    return;
  } else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp>(op)) {
  } else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp, AtenVarCorrectionOp>(op)) {
    auto input = operands[0]->getValue();
    visitReductionAlongAllDimsOp(op, input.dtype, operands);
    return;
@@ -5570,6 +5570,36 @@ module {
    %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
    return %1 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
    %true = torch.constant.bool true
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
    %1 = torch.prim.If %0 -> (!torch.bool) {
      torch.prim.If.yield %true : !torch.bool
    } else {
      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
      %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
      %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
      torch.prim.If.yield %7 : !torch.bool
    }
    %2 = torch.prim.If %1 -> (!torch.list<int>) {
      %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
      %6 = torch.prim.ListConstruct : () -> !torch.list<int>
      torch.prim.Loop %5, %true, init() {
      ^bb0(%arg4: !torch.int):
        %7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
        torch.prim.Loop.condition %true, iter()
      } : (!torch.int, !torch.bool) -> ()
      torch.prim.If.yield %6 : !torch.list<int>
    } else {
      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
      torch.prim.If.yield %5 : !torch.list<int>
    }
    %3 = torch.derefine %none : !torch.none to !torch.any
    %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
    return %4 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {
    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
    return %0 : !torch.list<int>
@@ -492,6 +492,11 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
    return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
    if dim is None or len(dim)==0:
        dim = list(range(len(self)))
    return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
    return []

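To make the intent of the new shape function concrete, here is a small self-contained check (the `mean_dim_shape` helper is a simplified stand-in for `upstream_shape_functions.mean_dim`, written only for illustration):

from typing import List

def mean_dim_shape(self: List[int], dim: List[int], keepdim: bool) -> List[int]:
    # Simplified stand-in: drop each reduced dim, or keep it as size 1 if keepdim.
    out = []
    for i, size in enumerate(self):
        if i in dim:
            if keepdim:
                out.append(1)
        else:
            out.append(size)
    return out

assert mean_dim_shape([3, 4, 7], [2], True) == [3, 4, 1]    # matches the var.correction lit test
assert mean_dim_shape([3, 4, 7], [0, 1, 2], False) == []    # dim=None/[] reduces every dimension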
@@ -383,6 +383,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::std : (Tensor, bool) -> (Tensor)")
    emit("aten::var : (Tensor, bool) -> (Tensor)")
    emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
    emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
    emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
    emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
    emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
@@ -427,3 +427,160 @@ class VarDimKeepDimFalseModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule())
def VarDimKeepDimFalseModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class VarCorrectionModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=None, correction=2)


@register_test_case(module_factory=lambda: VarCorrectionModule())
def VarCorrectionModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionSingleDimReduceModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=[1], correction=1)


@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule())
def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionAllDimReduceModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x,
                                  dim=[0, 1, 2],
                                  correction=10,
                                  keepdim=False)


@register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule())
def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionKeepDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)


@register_test_case(module_factory=lambda: VarCorrectionKeepDimModule())
def VarCorrectionKeepDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionNoneModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=None, correction=None)


@register_test_case(module_factory=lambda: VarCorrectionNoneModule())
def VarCorrectionNoneModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionEmptyDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=[], correction=2)


@register_test_case(module_factory=lambda: VarCorrectionEmptyDimModule())
def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))


# ==============================================================================


class VarCorrectionLargeInputModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, dim=[2, 3], correction=2)


@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 1024, 8192))
@@ -233,28 +233,36 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST1_2:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
  %true = torch.constant.bool true
  %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
@@ -269,26 +277,34 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32>
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
  %false = torch.constant.bool false
  %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
@@ -303,28 +319,36 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST1_2:.*]] = torch.constant.int 1
// CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
  %true = torch.constant.bool true
@@ -340,26 +364,34 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[CST1_1:.*]] = torch.constant.int 1
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
  %false = torch.constant.bool false
@@ -1165,22 +1197,30 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,1],f32>, !torch.float -> !torch.vtensor<[3,4,7],f32>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,7],f32> -> !torch.vtensor<[3,4,7],f32>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[VAR]] : !torch.vtensor<[3,4,1],f32>
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
  %int2 = torch.constant.int 2
  %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
@@ -1208,3 +1248,46 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
  %ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
  return %ret : !torch.tensor<[2,3],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.var.correction(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
// CHECK: %[[NUM_ELEMENTS_PLUS_ONE:.*]] = torch.aten.add.int %[[NUM_ELEMENTS_0]], %[[CST1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[PRED:.*]] = torch.aten.ge.int %[[NUM_ELEMENTS_PLUS_ONE]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED]], "correction value should be less than or equal to productDimSize + 1"
// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
  %int2 = torch.constant.int 2
  %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
  %keepdim = torch.constant.bool true
  %0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32>
  return %0 : !torch.vtensor<[3,4,1],f32>
}