mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like… (#2563)
… AtenBernoulli_FloatOp. This fixes cases like: `%2110 = torch.aten.arange.start_out %int1, %int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int, !torch.tensor -> !torch.tensor`. `aten.arange.start_out` also doesn't have value semantics, which means `%2110` is an alias for `%2109`. So I decompose it into `aten.arange.start` + `torch.contents.overwrite`. The more complex decomposition logic is intended to handle cases like view and dtype cast, which I added as e2e tests. rollpytorch snapshot-20231117.1025
parent
dad1f012f6
commit
facbe5d96b
|
@ -191,6 +191,16 @@ private:
|
|||
// Reduce ops without value semantics for which the corresponding variant
|
||||
// without the trailing underscore doesn't exist.
|
||||
namespace {
|
||||
|
||||
// Emits the ops that compute how many elements an arange described by
// `start`, `end` and `step` produces: int(ceil((end - start) / step)).
// All ops are created through `rewriter` at location `loc`; the returned
// Value is the result of the final AtenCeilFloatOp.
|
||||
Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
|
||||
Value start, Value end, Value step) {
|
||||
// sub = end - start, with the result typed as Torch::NumberType.
Value sub = rewriter.create<AtenSubOp>(
|
||||
loc, Torch::NumberType::get(rewriter.getContext()), end, start);
|
||||
// div = sub / step
Value div = rewriter.create<AtenDivOp>(loc, sub, step);
|
||||
// ceil(div) is the element count handed back to the caller.
return rewriter.create<AtenCeilFloatOp>(loc, div);
|
||||
}
|
||||
|
||||
class ReduceNonValueSemanticOps : public RewritePattern {
|
||||
public:
|
||||
ReduceNonValueSemanticOps(MLIRContext *context)
|
||||
|
@ -198,19 +208,54 @@ public:
|
|||
// Rewrites the ops handled below into an op that produces a fresh result,
// followed by an overwrite of the original destination tensor with that
// result. Returns failure() for any op that is neither AtenBernoulli_FloatOp
// nor AtenArangeStartOutOp.
// NOTE(review): this listing comes from a diff view without +/- markers, so
// some lines (the two `newOp` declaration/assignment lines and the statements
// after the if/else chain, where every branch already returns) appear to be
// old and new versions shown together -- confirm against the actual file.
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Operation *newOp;
|
||||
MLIRContext *ctx = op->getContext();
|
||||
// AtenBernoulli_FloatOp: create the ValsemVariant op with the same operands
// and result types, copy its result to a value tensor, overwrite operand 0
// with it, and replace the original op with operand 0.
if (isa<AtenBernoulli_FloatOp>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
||||
Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
|
||||
createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
return success();
|
||||
// AtenArangeStartOutOp: build a new arange using `out`'s dtype and device,
// reshape it to `out`'s shape, and overwrite `out` with the result.
} else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
|
||||
Value start = arangeOutOp.getStart();
|
||||
Value end = arangeOutOp.getEnd();
|
||||
Value step = arangeOutOp.getStep();
|
||||
Value out = arangeOutOp.getOut();
|
||||
|
||||
// `overwrite.tensor.contents` cannot change the tensor shape,
|
||||
// so `out` tensor should have same num_elements with result tensor.
|
||||
// It means that we don't support code like:
|
||||
// `x = torch.randn(12)`
|
||||
// `y = torch.arange(13, out=x)`
|
||||
Value resultNumElements =
|
||||
calculateArangeResultNumElements(rewriter, loc, start, end, step);
|
||||
Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
|
||||
Value eqOrNot =
|
||||
rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
|
||||
// Emit a runtime assert on the element-count equality computed above.
// NOTE(review): the message contains the typo "tenosr"; it is a runtime
// string, so it is left untouched here.
rewriter.create<RuntimeAssertOp>(
|
||||
loc, eqOrNot,
|
||||
rewriter.getStringAttr("`out` tensor should have the same "
|
||||
"num_elements with result tenosr"));
|
||||
|
||||
// Query dtype, device and shape from `out` so the new arange matches it.
auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
|
||||
auto device = rewriter.create<PrimDeviceOp>(loc, out);
|
||||
auto shape = rewriter.create<AtenSizeOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
|
||||
// A single `none` constant is passed for both layout and pin_memory.
auto none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value newArange = rewriter.create<AtenArangeStartStepOp>(
|
||||
loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
|
||||
/*layout=*/none, device, /*pin_memory=*/none);
|
||||
// Reshape the new arange result to the shape queried from `out`.
Value reshape = rewriter.create<AtenReshapeOp>(
|
||||
loc, arangeOutOp.getResult().getType(), newArange, shape);
|
||||
|
||||
auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
|
||||
createOverwriteTensorContents(rewriter, loc, vtensor, out);
|
||||
// The op's result is replaced by `out` itself.
rewriter.replaceOp(arangeOutOp, out);
|
||||
return success();
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
|
||||
createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
|
|||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.addIllegalOp<AtenArangeStartOutOp>();
|
||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
|
||||
Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
||||
|
|
|
@ -302,6 +302,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
|
||||
# ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
|
||||
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
|
||||
"ArangeStartOutViewModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
|
@ -1303,6 +1306,8 @@ TOSA_PASS_SET = {
|
|||
"AtenEyeModuleFalsePinMemory_basic",
|
||||
"AtenEyeModuleFloat2D_basic",
|
||||
"MeanModule_basic",
|
||||
"ArangeStartOutModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
|
@ -1372,6 +1377,7 @@ LTC_XFAIL_SET = {
|
|||
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||
"AddIntModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
|
|
|
@ -248,3 +248,53 @@ class ArangeFalsePinMemoryModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
||||
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# Test module: calls torch.arange with an `out=` tensor argument.
# `x` is annotated as a [12] int64 tensor and the range is 0..12.
class ArangeStartOutModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([12], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# Returns the result of arange written through `out=x`.
return torch.arange(start=0, end=12, out=x)
|
||||
|
||||
# Registered test case: runs ArangeStartOutModule.forward on a 12-element
# int64 tensor of zeros.
@register_test_case(module_factory=lambda: ArangeStartOutModule())
|
||||
def ArangeStartOutModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(12).to(torch.int64))
|
||||
|
||||
# Test module: calls torch.arange with an `out=` tensor annotated as
# [3, 4] int64 (12 elements in a 2-D shape) and the range 1..13.
class ArangeStartOutViewModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# Returns the result of arange written through the 2-D `out=x`.
return torch.arange(start=1, end=13, out=x)
|
||||
|
||||
# Registered test case: runs ArangeStartOutViewModule.forward on a 3x4
# int64 tensor of zeros.
@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
|
||||
def ArangeStartOutViewModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(3, 4).to(torch.int64))
|
||||
|
||||
# Test module: calls torch.arange with floating-point bounds (1.1, 13.1)
# while the `out=` tensor is annotated as [12] int64.
class ArangeStartOutDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([12], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# Float start/end combined with an int64 `out=x`.
return torch.arange(start=1.1, end=13.1, out=x)
|
||||
|
||||
# Registered test case: runs ArangeStartOutDtypeModule.forward on a
# 12-element int64 tensor of zeros.
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
|
||||
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(12).to(torch.int64))
|
||||
|
|
Loading…
Reference in New Issue