mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like AtenBernoulli_FloatOp (#2563)

It fixes cases like: `%2110 = torch.aten.arange.start_out %int1, %int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int, !torch.tensor -> !torch.tensor`. `aten.arange.start_out` also lacks value semantics, which means `%2110` is an alias for `%2109`. So I decompose it into `aten.arange.start_step` + `torch.overwrite.tensor.contents`. The somewhat involved decomposition logic is there to handle cases such as views and dtype casts, which are covered by the new e2e tests.
parent dad1f012f6
commit facbe5d96b
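Before the diff, a small plain-PyTorch sketch (my own illustration, not part of this patch; tensor names are arbitrary) of the aliasing behavior described above: with `out=`, the returned tensor and the `out` argument share storage, which is why `aten.arange.start_out` has no value semantics.

```python
import torch

# torch.arange(..., out=x) writes into x and returns an alias of it,
# so the op cannot be treated as having value semantics.
x = torch.zeros(12, dtype=torch.int64)
y = torch.arange(1, 13, out=x)

assert y.data_ptr() == x.data_ptr()          # y aliases x
assert torch.equal(x, torch.arange(1, 13))   # x now holds the arange result
```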
@@ -191,6 +191,16 @@ private:
 // Reduce Ops without value semantics but the corresponding without trailing
 // underscore variant doesn't exist.
 namespace {
+
+// int(ceil((end - start) / step))
+Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc,
+                                       Value start, Value end, Value step) {
+  Value sub = rewriter.create<AtenSubOp>(
+      loc, Torch::NumberType::get(rewriter.getContext()), end, start);
+  Value div = rewriter.create<AtenDivOp>(loc, sub, step);
+  return rewriter.create<AtenCeilFloatOp>(loc, div);
+}
+
 class ReduceNonValueSemanticOps : public RewritePattern {
 public:
   ReduceNonValueSemanticOps(MLIRContext *context)
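As a sanity check on the helper above, here is a minimal Python sketch (my own illustration, not code from this patch) of the same `int(ceil((end - start) / step))` element count that `calculateArangeResultNumElements` builds out of `AtenSubOp`, `AtenDivOp`, and `AtenCeilFloatOp`:

```python
import math
import torch

def arange_numel(start, end, step):
    # Same formula the helper materializes as Torch-dialect ops.
    return int(math.ceil((end - start) / step))

assert arange_numel(0, 12, 1) == torch.arange(0, 12, 1).numel()  # 12
assert arange_numel(1, 13, 2) == torch.arange(1, 13, 2).numel()  # 6
```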
@@ -198,19 +208,54 @@ public:
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    Operation *newOp;
+    MLIRContext *ctx = op->getContext();
     if (isa<AtenBernoulli_FloatOp>(op)) {
-      newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+      Operation *newOp = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
           loc, op->getResultTypes(), op->getOperands());
+      auto tensor =
+          rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
+      createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
+      rewriter.replaceOp(op, op->getOperand(0));
+      return success();
+    } else if (auto arangeOutOp = dyn_cast<AtenArangeStartOutOp>(op)) {
+      Value start = arangeOutOp.getStart();
+      Value end = arangeOutOp.getEnd();
+      Value step = arangeOutOp.getStep();
+      Value out = arangeOutOp.getOut();
+
+      // `overwrite.tensor.contents` cannot change the tensor shape, so the
+      // `out` tensor must have the same num_elements as the result tensor.
+      // This means we don't support code like:
+      //   `x = torch.randn(12)`
+      //   `y = torch.arange(13, out=x)`
+      Value resultNumElements =
+          calculateArangeResultNumElements(rewriter, loc, start, end, step);
+      Value outNumElements = rewriter.create<AtenNumelOp>(loc, out);
+      Value eqOrNot =
+          rewriter.create<AtenEqIntOp>(loc, resultNumElements, outNumElements);
+      rewriter.create<RuntimeAssertOp>(
+          loc, eqOrNot,
+          rewriter.getStringAttr("`out` tensor should have the same "
+                                 "num_elements as the result tensor"));
+
+      auto dtype = rewriter.create<PrimDtypeOp>(loc, out);
+      auto device = rewriter.create<PrimDeviceOp>(loc, out);
+      auto shape = rewriter.create<AtenSizeOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(ctx)), out);
+      auto none = rewriter.create<ConstantNoneOp>(loc);
+      Value newArange = rewriter.create<AtenArangeStartStepOp>(
+          loc, arangeOutOp.getResult().getType(), start, end, step, dtype,
+          /*layout=*/none, device, /*pin_memory=*/none);
+      Value reshape = rewriter.create<AtenReshapeOp>(
+          loc, arangeOutOp.getResult().getType(), newArange, shape);
+
+      auto vtensor = rewriter.create<CopyToValueTensorOp>(loc, reshape);
+      createOverwriteTensorContents(rewriter, loc, vtensor, out);
+      rewriter.replaceOp(arangeOutOp, out);
+      return success();
     } else {
       return failure();
     }
-
-    auto tensor =
-        rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
-    createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
-    rewriter.replaceOp(op, op->getOperand(0));
-    return success();
   }
 };
 } // namespace
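An eager-mode analogue of the decomposition above may make the intent clearer. This is only a sketch under my reading of the pattern (the helper name and the use of `copy_` are illustrative stand-ins, not the actual ops): compute a value-semantic arange with the dtype/device of `out`, reshape it to `out`'s shape so view-shaped outputs work, and overwrite `out`'s contents, after asserting that the element counts match.

```python
import math
import torch

def arange_start_out_decomposed(start, end, step, out):
    # Mirrors the RuntimeAssertOp: overwriting contents cannot change the
    # shape, so the element counts must agree.
    numel = int(math.ceil((end - start) / step))
    assert numel == out.numel(), \
        "`out` tensor should have the same num_elements as the result tensor"
    # Plays the role of aten.arange.start_step with dtype/device taken from `out`.
    result = torch.arange(start, end, step, dtype=out.dtype, device=out.device)
    # reshape + copy_ stand in for AtenReshapeOp + overwrite.tensor.contents.
    out.copy_(result.reshape(out.shape))
    return out

x = torch.zeros(3, 4, dtype=torch.int64)   # view-shaped `out`, as in ArangeStartOutViewModule
arange_start_out_decomposed(1, 13, 1, x)
```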
@@ -309,6 +354,7 @@ struct ReduceOpVariantsPass
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
+    target.addIllegalOp<AtenArangeStartOutOp>();
     target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
                                              Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
@@ -302,6 +302,9 @@ TORCHDYNAMO_XFAIL_SET = {
 
     # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
     "ThresholdBackward2dMixedModule_basic",
+
+    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
+    "ArangeStartOutViewModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.1.0.dev"):
@@ -1303,6 +1306,8 @@ TOSA_PASS_SET = {
     "AtenEyeModuleFalsePinMemory_basic",
     "AtenEyeModuleFloat2D_basic",
     "MeanModule_basic",
+    "ArangeStartOutModule_basic",
+    "ArangeStartOutViewModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1372,6 +1377,7 @@ LTC_XFAIL_SET = {
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "AddIntModule_basic",
+    "ArangeStartOutViewModule_basic",
     "AtenIntBoolOpModule_basic",
     "BernoulliTensorModule_basic",
     "BincountMinlengthModule_basic",
@@ -248,3 +248,53 @@ class ArangeFalsePinMemoryModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
 def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward()
+
+# ==============================================================================
+
+class ArangeStartOutModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=0, end=12, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutModule())
+def ArangeStartOutModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
+
+class ArangeStartOutViewModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1, end=13, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutViewModule())
+def ArangeStartOutViewModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(3, 4).to(torch.int64))
+
+class ArangeStartOutDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([12], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.arange(start=1.1, end=13.1, out=x)
+
+@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
+def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
+    module.forward(torch.zeros(12).to(torch.int64))
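One note on ArangeStartOutDtypeModule: with float `start`/`end` but an int64 `out`, eager PyTorch takes the result dtype from `out`, which is the cast path the `PrimDtypeOp`-based decomposition has to reproduce. A quick illustrative check (my own, not part of the test suite):

```python
import torch

x = torch.zeros(12, dtype=torch.int64)
y = torch.arange(start=1.1, end=13.1, out=x)
assert y.dtype == torch.int64   # dtype comes from `out`, not from the float endpoints
```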