Implement lowering of torch.aten.remainder.Tensor (#2763)

Closes nod-ai/SHARK-Turbine#349
pull/2775/head
Ilija Kalinić 2024-01-19 13:39:08 +01:00 committed by GitHub
parent 4de4d38b87
commit faa4517e83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 118 additions and 6 deletions

View File

@ -1195,6 +1195,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result;
}
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
Type newResultType = converter->convertType(remTensor.getType())
.cast<RankedTensorType>()
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (newResultType.isa<mlir::IntegerType>()) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remTensor.emitError(
"Unsupported type encountered for AtenRemainderTensorOp.");
}
return result;
}
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
Type dtype = converter->convertType(reciprocal.getType())
.cast<RankedTensorType>()
@ -1457,8 +1477,8 @@ public:
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
@ -1471,7 +1491,8 @@ public:
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op))
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -2239,9 +2260,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -6758,6 +6758,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.remainder.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10725,6 +10729,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"

View File

@ -383,6 +383,9 @@ def atendivScalar〡shape(self: List[int], other: float) -> List[int]:
def atenremainderScalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)
def atenremainderTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfloor_divideScalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)
@ -3224,6 +3227,14 @@ def atenremainderScalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def atenremainderTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) +

View File

@ -2265,6 +2265,73 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b):
return torch.remainder(a, b)
@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float())
def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=10).to(torch.int32), tu.rand(3, 4, high=10))
# ==============================================================================
class ElementwiseRemainderTensorModule_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b):
return torch.remainder(a, b)
@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Float())
def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10))
# ==============================================================================
class ElementwiseRemainderTensorModule_Int(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.int32, True),
])
def forward(self, a, b):
return torch.remainder(a, b)
@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int())
def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=10, dtype=torch.int32), tu.randint(3, 4, high=10, dtype=torch.int32))
# ==============================================================================
class ElementwiseDivTensorFloatModule(torch.nn.Module):
def __init__(self):