mirror of https://github.com/llvm/torch-mlir
Add aten.std.correction op and its decomposition (#1731)
parent
50b524546f
commit
60a139271d
|
@ -628,7 +628,6 @@ LTC_XFAIL_SET = {
|
|||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AddIntModule_basic",
|
||||
"BernoulliFloatModule_basic",
|
||||
"BernoulliModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
"BincountModule_basic",
|
||||
|
@ -639,7 +638,6 @@ LTC_XFAIL_SET = {
|
|||
"BoolIntTrueModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"DivFloatModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||
"ElementwiseAtenFloorDivideModule_basic",
|
||||
"EqIntModule_basic",
|
||||
|
@ -712,13 +710,6 @@ LTC_XFAIL_SET = {
|
|||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"StdDimBiasedModule_basic",
|
||||
"StdDimKeepDimFalseModule_basic",
|
||||
"StdDimKeepDimTrueModule_basic",
|
||||
"StdDimEmptyDimModule_basic",
|
||||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
|
|
|
@ -4905,6 +4905,32 @@ def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenStdCorrectionOp : Torch_Op<"aten.std.correction", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalIntType:$correction,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenStdCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenStdCorrectionOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenVarOp : Torch_Op<"aten.var", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -5839,6 +5839,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
|
|
|
@ -1710,6 +1710,32 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.std.correction to sqrt(var.correction(x))
|
||||
namespace {
|
||||
class DecomposeAtenStdCorrectionOp
|
||||
: public OpRewritePattern<AtenStdCorrectionOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.getSelf();
|
||||
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
||||
if (!inputTensorType.hasDtype() ||
|
||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"aten.std.correction expects input tensor of floating-point type");
|
||||
}
|
||||
|
||||
Value varCorrection = rewriter.create<AtenVarCorrectionOp>(
|
||||
op->getLoc(), op.getType(), self, op.getDim(), op.getCorrection(),
|
||||
op.getKeepdim());
|
||||
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varCorrection);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
||||
namespace {
|
||||
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
|
||||
|
@ -3511,6 +3537,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
|
||||
|
|
|
@ -429,6 +429,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenAmaxOp>();
|
||||
target.addIllegalOp<AtenVarCorrectionOp>();
|
||||
target.addIllegalOp<AtenStdDimOp>();
|
||||
target.addIllegalOp<AtenStdCorrectionOp>();
|
||||
target.addIllegalOp<AtenNarrowOp>();
|
||||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||
|
|
|
@ -976,8 +976,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
Type dtype = operands[0]->getValue().dtype;
|
||||
visitReductionAlongAllDimsOp(op, dtype, operands);
|
||||
return;
|
||||
} else if (isa<AtenStdOp, AtenStdDimOp, AtenVarOp, AtenVarDimOp,
|
||||
AtenVarCorrectionOp>(op)) {
|
||||
} else if (isa<AtenStdOp, AtenStdDimOp, AtenStdCorrectionOp, AtenVarOp,
|
||||
AtenVarDimOp, AtenVarCorrectionOp>(op)) {
|
||||
auto input = operands[0]->getValue();
|
||||
visitReductionAlongAllDimsOp(op, input.dtype, operands);
|
||||
return;
|
||||
|
|
|
@ -320,6 +320,9 @@ def aten〇std〡shape(self: List[int], unbiased: bool = True) -> List[int]:
|
|||
def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
|
||||
out: List[int] = []
|
||||
|
|
|
@ -405,6 +405,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::mean : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||
emit("aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
|
||||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
|
||||
|
|
|
@ -405,6 +405,163 @@ def StdDimNoneDimModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=None, correction=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionModule())
|
||||
def StdCorrectionModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionSingleDimReduceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[1], correction=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionSingleDimReduceModule())
|
||||
def StdCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionAllDimReduceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x,
|
||||
dim=[0, 1, 2],
|
||||
correction=10,
|
||||
keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionAllDimReduceModule())
|
||||
def StdCorrectionAllDimReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionKeepDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[0, 1], correction=None, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionKeepDimModule())
|
||||
def StdCorrectionKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionNoneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=None, correction=None)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionNoneModule())
|
||||
def StdCorrectionNoneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionEmptyDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[], correction=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionEmptyDimModule())
|
||||
def StdCorrectionEmptyDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdCorrectionLargeInputModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[2, 3], correction=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdCorrectionLargeInputModule())
|
||||
def StdCorrectionLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 1024, 8192, low=100.0, high=101.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -754,7 +911,7 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule())
|
||||
def VarCorrectionLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(100 + tu.rand(3, 4, 1024, 8192))
|
||||
module.forward(tu.rand(3, 4, 1024, 8192, low=100.0, high=101.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
Loading…
Reference in New Issue