mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.div.Scalar
This commit adds lowering of `aten.div.Scalar`. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/439/head snapshot-20211124.103
parent
56c6e3676b
commit
8d8d2c2fb8
|
@ -479,3 +479,21 @@ class ElementwiseRsqrtModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseRsqrtModule())
|
||||
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
class ElementwiseDivScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.div(x, 10.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
|
||||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
|
|
@ -1594,6 +1594,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value result = convertScalarToDtype(b, loc, input, dtype);
|
||||
return result;
|
||||
}
|
||||
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
|
||||
Type dtype = converter->convertType(divScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
divScalar.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value self = payloadArgs[0];
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
return b.create<arith::DivFOp>(loc, self, other);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
|
@ -1805,7 +1817,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>(op))
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp>(
|
||||
op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
|
Loading…
Reference in New Issue