[MLIR][TORCH] Add e2e support for aten.var_mean op

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1715/head
Vivek Khandelwal 2022-12-09 20:52:26 +05:30
parent 143a8f378d
commit d4862ec611
9 changed files with 106 additions and 3 deletions

View File

@ -811,4 +811,6 @@ LTC_XFAIL_SET = {
"NllLossModule_sum_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
}

View File

@ -5008,6 +5008,31 @@ def Torch_AtenVarMeanCorrectionOp : Torch_Op<"aten.var_mean.correction", [
}];
}
def Torch_AtenVarMeanOp : Torch_Op<"aten.var_mean", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_BoolType:$unbiased
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenVarMeanOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenVarMeanOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}
def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -3284,6 +3284,26 @@ public:
};
} // namespace
namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarMeanOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
Value var = rewriter.create<AtenVarDimOp>(loc, op.getType(0), op.getSelf(),
/*dim=*/noneVal, op.getUnbiased(),
/*keepdim=*/falseVal);
Value mean = rewriter.create<AtenMeanOp>(loc, op.getType(0), op.getSelf(),
/*dtype=*/noneVal);
rewriter.replaceOp(op, {var, mean});
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -3428,6 +3448,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;

View File

@ -418,6 +418,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<PrimsConvertElementTypeOp>();
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenVarMeanOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
}

View File

@ -1182,7 +1182,7 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
if (isa<AtenVarMeanCorrectionOp>(op)) {
if (isa<AtenVarMeanCorrectionOp, AtenVarMeanOp>(op)) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());

View File

@ -5780,6 +5780,12 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.std\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -544,13 +544,16 @@ def atenvar(self: List[int], unbiased: bool = True) -> List[int]:
def atenvardim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenvarcorrection(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
def atenvarcorrection(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
def atenvar_meancorrection(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> Tuple[List[int], List[int]]:
def atenvar_meancorrection(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
return out, out
def atenvar_mean(self: List[int], unbiased: bool = True) -> Tuple[List[int], List[int]]:
return [], []
def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
return []

View File

@ -409,6 +409,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
emit("aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)")
emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")

View File

@ -799,3 +799,47 @@ class VarMeanCorrectionNoneModule(torch.nn.Module):
@register_test_case(module_factory=lambda: VarMeanCorrectionNoneModule())
def VarMeanCorrectionNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarMeanUnbiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x)
@register_test_case(module_factory=lambda: VarMeanUnbiasedModule())
def VarMeanUnbiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class VarMeanBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, unbiased=False)
@register_test_case(module_factory=lambda: VarMeanBiasedModule())
def VarMeanBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))