[MLIR][TORCH] Add E2E support for aten.div.Scalar

This commit adds lowering of `aten.div.Scalar`.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/439/head snapshot-20211124.103
Vivek Khandelwal 2021-11-23 12:24:31 +05:30
parent 56c6e3676b
commit 8d8d2c2fb8
2 changed files with 32 additions and 1 deletions

View File

@ -479,3 +479,21 @@ class ElementwiseRsqrtModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseRsqrtModule())
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseDivScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.div(x, 10.0)
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

View File

@ -1594,6 +1594,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value result = convertScalarToDtype(b, loc, input, dtype);
return result;
}
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
Type dtype = converter->convertType(divScalar.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
divScalar.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value self = payloadArgs[0];
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::DivFOp>(loc, self, other);
}
op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp");
@ -1805,7 +1817,8 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>(op))
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp>(
op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))