[MLIR][TORCH] Add E2E support for aten.var_mean.dim op

This commit adds the decomposition for the aten.var_mean.dim op.

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/2070/head
Vivek Khandelwal 2023-04-26 07:14:06 +00:00
parent c8e062fb4e
commit 491ae5eda4
8 changed files with 115 additions and 0 deletions

View File

@ -945,4 +945,6 @@ LTC_XFAIL_SET = {
"PrimsViewOfModule_basic", "PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic", "PrimsViewOfZeroRankModule_basic",
"OneHotModule_basic", "OneHotModule_basic",
"VarMeanDimModule_basic",
"VarMeanDimBiasedModule_basic",
} }

View File

@ -5385,6 +5385,33 @@ def Torch_AtenVarMeanOp : Torch_Op<"aten.var_mean", [
}]; }];
} }
def Torch_AtenVarMeanDimOp : Torch_Op<"aten.var_mean.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenVarMeanDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 2);
}
void AtenVarMeanDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 2);
}
}];
}
def Torch_AtenNllLoss2dForwardOp : Torch_Op<"aten.nll_loss2d_forward", [ def Torch_AtenNllLoss2dForwardOp : Torch_Op<"aten.nll_loss2d_forward", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -6501,6 +6501,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n" " %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n" " return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n" " func.func @\"__torch_mlir_shape_fn.aten.var_mean\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n" " %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n" " %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"

View File

@ -4298,6 +4298,28 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
} // namespace } // namespace
namespace {
// Decompose `aten.var_mean.dim` op into `aten.var.dim` and
// `aten.mean.dim` op.
class DecomposeAtenVarMeanDimOp : public OpRewritePattern<AtenVarMeanDimOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarMeanDimOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value var = rewriter.create<AtenVarDimOp>(loc, op.getType(0), op.getSelf(),
op.getDim(), op.getUnbiased(),
op.getKeepdim());
Value mean = rewriter.create<AtenMeanDimOp>(
loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
/*dtype=*/noneVal);
rewriter.replaceOp(op, {var, mean});
return success();
}
};
} // namespace
namespace { namespace {
class DecomposeComplexOpsPass class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> { : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -4460,6 +4482,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;

View File

@ -476,6 +476,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMovedimIntOp>(); target.addIllegalOp<AtenMovedimIntOp>();
target.addIllegalOp<AtenOneHotOp>(); target.addIllegalOp<AtenOneHotOp>();
target.addIllegalOp<AtenCrossEntropyLossOp>(); target.addIllegalOp<AtenCrossEntropyLossOp>();
target.addIllegalOp<AtenVarMeanDimOp>();
for (auto &opName : backendLegalOpsSet) { for (auto &opName : backendLegalOpsSet) {
target.addLegalOp( target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context)); OperationName(kTorchOpPrefix + opName.first().str(), context));

View File

@ -332,6 +332,14 @@ def atenvar_meancorrection〡shape(self: List[int], dim: Optional[List[int
out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
return out, out return out, out
def atenvar_meandim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> Tuple[List[int], List[int]]:
out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
return out, out
def atenvar_meandim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> Tuple[int, int]:
_, self_dtype = self_rank_dtype
return self_dtype, self_dtype
def atenvar_mean〡shape(self: List[int], unbiased: bool = True) -> Tuple[List[int], List[int]]: def atenvar_mean〡shape(self: List[int], unbiased: bool = True) -> Tuple[List[int], List[int]]:
return [], [] return [], []

View File

@ -422,6 +422,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)") emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)")
emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)") emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)")
emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)") emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)")
emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)")
emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")

View File

@ -1000,3 +1000,44 @@ class VarMeanBiasedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarMeanBiasedModule()) @register_test_case(module_factory=lambda: VarMeanBiasedModule())
def VarMeanBiasedModule_basic(module, tu: TestUtils): def VarMeanBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7)) module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarMeanDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, dim=[1])
@register_test_case(module_factory=lambda: VarMeanDimModule())
def VarMeanDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
class VarMeanDimBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, dim=[1], unbiased=False, keepdim=True)
@register_test_case(module_factory=lambda: VarMeanDimBiasedModule())
def VarMeanDimBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))