mirror of https://github.com/llvm/torch-mlir
[TORCH] Add aten.std e2e support
parent e58b66bc3b
commit 5d9a15263a
@@ -1006,6 +1006,7 @@ class TModuleRank2(torch.nn.Module):
def TModuleRank2_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))

# ==============================================================================

class TModuleRank1(torch.nn.Module):
    def __init__(self):
@@ -1023,6 +1024,7 @@ class TModuleRank1(torch.nn.Module):
def TModuleRank1_basic(module, tu: TestUtils):
    module.forward(tu.rand(3))

# ==============================================================================

class TModuleRank0(torch.nn.Module):
    def __init__(self):
@@ -1040,6 +1042,8 @@ class TModuleRank0(torch.nn.Module):
def TModuleRank0_basic(module, tu: TestUtils):
    module.forward(torch.tensor(7, dtype=torch.float32))

# ==============================================================================

class TensorLiteralModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -1057,6 +1061,7 @@ class TensorLiteralModule(torch.nn.Module):
def TensorLiteralModule_basic(module, tu: TestUtils):
    module.forward()

# ==============================================================================

class TensorOpaqueLiteralModule(torch.nn.Module):
    def __init__(self):
@@ -1075,6 +1080,8 @@ class TensorOpaqueLiteralModule(torch.nn.Module):
def TensorOpaqueLiteralModule_basic(module, tu: TestUtils):
    module.forward()

# ==============================================================================

class ReturnTwoTensorF32I64(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -1092,6 +1099,7 @@ class ReturnTwoTensorF32I64(torch.nn.Module):
def ReturnTwoTensorF32I64_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.randint(5, (2, 3)))

# ==============================================================================

class IndexTensorModule(torch.nn.Module):
    def __init__(self):
@@ -1109,3 +1117,93 @@ class IndexTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: IndexTensorModule())
def IndexTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5), torch.randint(4, (2, 3)))

# ==============================================================================

class SquareModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.square(x)

@register_test_case(module_factory=lambda: SquareModule())
def SquareModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class VarUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=True)

@register_test_case(module_factory=lambda: VarUnbiasedModule())
def VarUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class VarBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=False)

@register_test_case(module_factory=lambda: VarBiasedModule())
def VarBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class StdUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=True)

@register_test_case(module_factory=lambda: StdUnbiasedModule())
def StdUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class StdBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=False)

@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
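Note: the Var*/Std* modules above pin down the biased vs. unbiased split that the e2e harness checks against eager PyTorch. A minimal standalone sketch of those numerics (illustrative shapes, plain PyTorch rather than the harness itself):

```python
import torch

x = torch.rand(2, 3, 4)
n = x.numel()

# Biased (population) variance: the mean of squared deviations.
biased = torch.square(x - x.mean()).sum() / n
assert torch.allclose(biased, torch.ops.aten.var(x, unbiased=False))

# Unbiased (sample) variance applies Bessel's correction: divide by n - 1.
unbiased = torch.square(x - x.mean()).sum() / (n - 1)
assert torch.allclose(unbiased, torch.ops.aten.var(x, unbiased=True))
```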
@@ -89,4 +89,5 @@ TOSA_PASS_SET = {
    "FlattenStaticModule_basic",
    "FlattenRank0Module_basic",
    "ElementwiseFlattenBroadcastModule_basic",
    "SquareModule_basic",
}
@@ -1172,6 +1172,34 @@ def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($threshold)) `,` qualified(type($value)) `->` qualified(type($result))";
}

def Torch_AtenSquareOp : Torch_Op<"aten.square", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::square : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::square_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -1759,6 +1787,36 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
  let assemblyFormat = "$self `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dtype)) `->` qualified(type($result))";
}

def Torch_AtenStdOp : Torch_Op<"aten.std", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_BoolType:$unbiased
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))";
}

def Torch_AtenVarOp : Torch_Op<"aten.var", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::var : (Tensor, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_BoolType:$unbiased
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $unbiased attr-dict `:` qualified(type($self)) `,` qualified(type($unbiased)) `->` qualified(type($result))";
}

def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
    AllowsTypeRefinement,
    HasValueSemantics
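These ODS entries mirror the registered `aten::std`/`aten::var` schemas that take a single bool. A quick way to see which overload a scripted call resolves to (a sketch; the printed graph text varies across PyTorch versions):

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.std(x, unbiased=False)

# Scripting `f` should yield a graph containing an aten::std node with a
# (Tensor, bool) signature, matching the summary strings above.
print(torch.jit.script(f).graph)
```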
@@ -403,7 +403,7 @@ public:
};
} // namespace

-// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
+// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
@@ -459,7 +459,7 @@ public:
};
} // namespace

-// Decompose torch.expand into torch.broadcast_to op.
+// Decompose aten.expand into aten.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
@@ -479,7 +479,7 @@ public:
};
} // namespace

-// Decompose torch.addmm into torch.mm and torch.add.Tensor op.
+// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
public:
@@ -519,7 +519,7 @@ public:
};
} // namespace

-// Decompose torch.mean into: sum(x)/div(numTensorElements).
+// Decompose aten.mean into: sum(x)/div(numTensorElements).
namespace {
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
public:
@@ -539,6 +539,94 @@ public:
};
} // namespace

namespace {
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSquareOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.self();
    rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
    return success();
  }
};
} // namespace

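The rewrite above turns `aten.square` into a self-multiply. The identity it relies on, as a plain PyTorch check (illustrative input):

```python
import torch

x = torch.rand(2, 3, 4)
# aten.square lowers to aten.mul.Tensor(x, x).
assert torch.allclose(torch.ops.aten.square(x), x * x)
```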
// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
// for the unbiased case, and mean(square(x - mean)) for the biased case.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenVarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
    if (!inputTensorTy.hasDtype() ||
        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "only floating-point types are supported for aten.var");
    }
    BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
    assert(rank0FloatTensorTy.getSizes().size() == 0 &&
           "op should have a rank-0 tensor result type");

    bool unbiased;
    if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
      return rewriter.notifyMatchFailure(
          op, "only constant unbiased is supported for aten.var");
    }

    Value dtype = rewriter.create<ConstantNoneOp>(loc);
    Value mean =
        rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
    Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
    Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
    if (unbiased) {
      // Bessel's correction is used: divide the sum of squares by
      // numTensorElements - 1.
      Value squareSum =
          rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
      Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
      Value cst1 = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value numTensorElementsSub1 =
          rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
      rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
          op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
    } else {
      rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy, square,
                                              dtype);
    }
    return success();
  }
};
} // namespace

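Replaying the emitted op sequence in eager PyTorch makes the decomposition easy to audit. A sketch under the same unbiased/biased split; `decomposed_var` is an illustrative helper, not part of the patch:

```python
import torch

def decomposed_var(x: torch.Tensor, unbiased: bool) -> torch.Tensor:
    mean = x.mean()                  # aten.mean
    sub_mean = x - mean              # aten.sub.Tensor (alpha = 1)
    square = sub_mean * sub_mean     # aten.square, itself lowered to aten.mul.Tensor
    if unbiased:
        # aten.sum / (aten.numel - 1), i.e. Bessel's correction.
        return square.sum() / (square.numel() - 1)
    return square.mean()             # aten.mean

x = torch.rand(2, 3, 4)
for unbiased in (True, False):
    assert torch.allclose(decomposed_var(x, unbiased),
                          torch.ops.aten.var(x, unbiased=unbiased))
```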
// Decompose aten.std into: sqrt(aten.var(x)).
namespace {
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStdOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.self();
    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
    if (!inputTensorTy.hasDtype() ||
        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "only floating-point types are supported for aten.std");
    }
    Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
                                           op.self(), op.unbiased());
    rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
    return success();
  }
};
} // namespace

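The `aten.std` pattern delegates all of the arithmetic to `aten.var` and only adds a square root. The relationship, checked in eager PyTorch (illustrative input):

```python
import torch

x = torch.rand(2, 3, 4)
# aten.std(x, unbiased) lowers to aten.sqrt(aten.var(x, unbiased)).
for unbiased in (True, False):
    assert torch.allclose(torch.ops.aten.std(x, unbiased=unbiased),
                          torch.ops.aten.var(x, unbiased=unbiased).sqrt())
```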
namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
@@ -730,6 +818,12 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenArangeStartOp>();
    patterns.add<DecomposeAtenArgMaxOp>(context);
    target.addIllegalOp<AtenArgmaxOp>();
    patterns.add<DecomposeAtenSquareOp>(context);
    target.addIllegalOp<AtenSquareOp>();
    patterns.add<DecomposeAtenVarOp>(context);
    target.addIllegalOp<AtenVarOp>();
    patterns.add<DecomposeAtenStdOp>(context);
    target.addIllegalOp<AtenStdOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
@@ -242,7 +242,7 @@ public:
            AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
            AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-           AtenAbsOp, AtenThresholdOp>(op)) {
+           AtenAbsOp, AtenThresholdOp, AtenSquareOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -469,6 +469,9 @@ public:
    } else if (auto max = dyn_cast<AtenMaxOp>(op)) {
      Type dtype = operands[0]->getValue().dtype;
      return visitReductionAlongAllDimsOp(max, dtype, operands);
    } else if (isa<AtenStdOp, AtenVarOp>(op)) {
      auto input = operands[0]->getValue();
      return visitReductionAlongAllDimsOp(op, input.dtype, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
    } else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
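Both ops reduce along all dimensions, so `visitReductionAlongAllDimsOp` refines them to a rank-0 result. The eager behavior this mirrors (a sketch, shapes only):

```python
import torch

x = torch.rand(2, 3, 4)
# std/var without a dim argument reduce over every dimension, producing a
# rank-0 tensor, i.e. the !torch.vtensor<[],f32> the refinement assigns.
assert torch.ops.aten.var(x, True).shape == torch.Size([])
assert torch.ops.aten.std(x, True).shape == torch.Size([])
```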
@@ -482,6 +482,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
            "aten::reciprocal : (Tensor) -> (Tensor)",
            "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
            "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
            "aten::square : (Tensor) -> (Tensor)",

    ]:
        emit_with_mutating_variants(key)
@@ -538,6 +539,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::sqrt : (Tensor) -> (Tensor)")
    emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
    emit("aten::mean : (Tensor, int?) -> (Tensor)")
    emit("aten::std : (Tensor, bool) -> (Tensor)")
    emit("aten::var : (Tensor, bool) -> (Tensor)")
    emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")

    # Misc tensor ops.
@@ -185,3 +185,103 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
  %0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}

// -----
// CHECK-LABEL: func @torch.aten.square(
// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK:         %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK:         return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}

// -----
// CHECK-LABEL: func @torch.aten.var$unbiased(
// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK:         %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK:         %[[DTYPE:.*]] = torch.constant.none
// CHECK:         %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK:         %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[CST1:.*]] = torch.constant.int 1
// CHECK:         %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK:         %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
  %true = torch.constant.bool true
  %0 = torch.aten.var %arg0, %true : !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----
// CHECK-LABEL: func @torch.aten.var$biased(
// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK:         %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK:         %[[DTYPE:.*]] = torch.constant.none
// CHECK:         %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK:         %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
  %false = torch.constant.bool false
  %0 = torch.aten.var %arg0, %false : !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----
// CHECK-LABEL: func @torch.aten.std$unbiased(
// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK:         %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK:         %[[DTYPE:.*]] = torch.constant.none
// CHECK:         %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK:         %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[CST1:.*]] = torch.constant.int 1
// CHECK:         %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK:         %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK:         return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
  %true = torch.constant.bool true
  %0 = torch.aten.std %arg0, %true : !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----
// CHECK-LABEL: func @torch.aten.std$biased(
// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK:         %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK:         %[[DTYPE:.*]] = torch.constant.none
// CHECK:         %[[SUM:.*]] = torch.aten.sum %[[INPUT]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[NUM_ELEMENTS:.*]] = torch.aten.numel %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK:         %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum %[[SUB_MEAN_SQUARE]], %[[DTYPE]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[],f32>
// CHECK:         %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK:         %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK:         %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK:         return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
  %false = torch.constant.bool false
  %0 = torch.aten.std %arg0, %false : !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}