[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like… (#2563)

… AtenBernoulli_FloatOp

It fixing case like: `%2110 = torch.aten.arange.start_out %int1,
%int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int,
!torch.tensor -> !torch.tensor`.
`aten.arange.start_out` doesn't have value semantics also, means`%2110`
is an alias for %2109.
So I decompose it to `aten.arange.start` + `torch.contents.overwrite`.  
The complex decomposition logic is target to handle cases like view and
dtype cast which I add in e2e tests.
rollpytorch snapshot-20231117.1025
Yuanqiang Liu 2023-11-17 00:51:55 +08:00 committed by GitHub
parent dad1f012f6
commit facbe5d96b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 8 deletions

View File

@ -191,6 +191,16 @@ private:
// Reduce Ops without value semantics but the corresponding without trailing
// underscore variant doesn't exist.
namespace {
// int(ceil((end - start) / step))
Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
Value start, Value end, Value step) {
Value sub = rewriter.create<AtenSubOp>(
loc, Torch::NumberType::get(rewriter.getContext()), end, start);
Value div = rewriter.create<AtenDivOp>(loc, sub, step);
return rewriter.create<AtenCeilFloatOp>(loc, div);
}
class ReduceNonValueSemanticOps : public RewritePattern {
public:
ReduceNonValueSemanticOps(MLIRContext *context)
@ -198,19 +208,54 @@ public:
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Operation *newOp;
MLIRContext *ctx = op->getContext();
if (isa<AtenBernoulli_FloatOp>(op)) {
newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
loc, op->getResultTypes(), op->getOperands());
} else {
return failure();
}
auto tensor =
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));
return success();
} else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
Value start = arangeOutOp.getStart();
Value end = arangeOutOp.getEnd();
Value step = arangeOutOp.getStep();
Value out = arangeOutOp.getOut();
// `overwrite.tensor.contents` cannot change the tensor shape,
// so `out` tensor should have same num_elements with result tensor.
// It means that we don't support code like:
// `x = torch.randn(12)`
// `y = torch.arange(13, out=x)`
Value resultNumElements =
calculateArangeResultNumElements(rewriter, loc, start, end, step);
Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
Value eqOrNot =
rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("`out` tensor should have the same "
"num_elements with result tenosr"));
auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
auto device = rewriter.create<PrimDeviceOp>(loc, out);
auto shape = rewriter.create<AtenSizeOp>(
loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
auto none = rewriter.create<ConstantNoneOp>(loc);
Value newArange = rewriter.create<AtenArangeStartStepOp>(
loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
/*layout=*/none, device, /*pin_memory=*/none);
Value reshape = rewriter.create<AtenReshapeOp>(
loc, arangeOutOp.getResult().getType(), newArange, shape);
auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
createOverwriteTensorContents(rewriter, loc, vtensor, out);
rewriter.replaceOp(arangeOutOp, out);
return success();
} else {
return failure();
}
}
};
} // namespace
@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenArangeStartOutOp>();
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||

View File

@ -302,6 +302,9 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
"ThresholdBackward2dMixedModule_basic",
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
"ArangeStartOutViewModule_basic",
}
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@ -1303,6 +1306,8 @@ TOSA_PASS_SET = {
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"MeanModule_basic",
"ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic",
}
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@ -1372,6 +1377,7 @@ LTC_XFAIL_SET = {
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"AddIntModule_basic",
"ArangeStartOutViewModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",

View File

@ -248,3 +248,53 @@ class ArangeFalsePinMemoryModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ArangeStartOutModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([12], torch.int64, True),
])
def forward(self, x):
return torch.arange(start=0, end=12, out=x)
@register_test_case(module_factory=lambda: ArangeStartOutModule())
def ArangeStartOutModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(12).to(torch.int64))
class ArangeStartOutViewModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int64, True),
])
def forward(self, x):
return torch.arange(start=1, end=13, out=x)
@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(3, 4).to(torch.int64))
class ArangeStartOutDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([12], torch.int64, True),
])
def forward(self, x):
return torch.arange(start=1.1, end=13.1, out=x)
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(12).to(torch.int64))