Implement lowering of torch.aten.fmod.Tensor (#2767)

Closing https://github.com/nod-ai/SHARK-Turbine/issues/351
pull/2954/head
mmakevic 2024-02-29 06:52:03 +01:00 committed by GitHub
parent f21b76b68a
commit 76b81e0ccd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 142 additions and 30 deletions

View File

@ -882,14 +882,14 @@ public:
if (bias.getType().isa<Torch::NoneType>()) {
Value c0;
if (resultDTy.isa<mlir::FloatType>()) {
c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(resultDTy, 0.0));
c0 = rewriter.create<arith::ConstantOp>(loc,
FloatAttr::get(resultDTy, 0.0));
} else if (resultDTy.isa<mlir::IntegerType>()) {
c0 = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(resultDTy, 0));
c0 = rewriter.create<arith::ConstantOp>(loc,
IntegerAttr::get(resultDTy, 0));
}
outputTensor = rewriter.create<linalg::FillOp>(loc, c0, initTensor)
.getResult(0);
outputTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
@ -1058,11 +1058,11 @@ public:
loc, collapsedType, weight, collapsedDims);
conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);

View File

@ -1274,6 +1274,29 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return result;
}
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
Type newResultType = converter->convertType(fmod.getType())
.cast<RankedTensorType>()
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
Value n = b.create<arith::DivFOp>(loc, self, other);
n = b.create<math::TruncOp>(loc, n);
Value n_y = b.create<arith::MulFOp>(loc, n, other);
result = b.create<arith::SubFOp>(loc, self, n_y);
} else if (newResultType.isa<mlir::IntegerType>()) {
Value n = b.create<arith::DivSIOp>(loc, self, other);
Value n_y = b.create<arith::MulIOp>(loc, n, other);
result = b.create<arith::SubIOp>(loc, self, n_y);
} else {
fmod.emitError("Unsupported type encountered for AtenFmodTensorOp.");
}
return result;
}
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
Type dtype = converter->convertType(reciprocal.getType())
.cast<RankedTensorType>()
@ -1541,22 +1564,22 @@ public:
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@ -2584,9 +2607,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -6865,6 +6865,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fmod.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -11395,6 +11399,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"

View File

@ -1641,6 +1641,9 @@ ONNX_XFAIL_SET = {
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"EqIntModule_basic",

View File

@ -438,6 +438,9 @@ def atenremainderScalar〡shape(self: List[int], other: float) -> List[int
def atenremainderTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfmodTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfloor_divideScalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)
@ -3491,6 +3494,14 @@ def atenfmodScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def atenfmodTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))

View File

@ -2526,6 +2526,68 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseFmodTensor_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.float32, True)
])
def forward(self, x, y):
return torch.fmod(x, y)
@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float())
def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils):
module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10))
# ==============================================================================
class ElementwiseFmodTensor_Int_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int32, True),
([-1], torch.float32, True)
])
def forward(self, x, y):
return torch.fmod(x, y)
@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float())
def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils):
module.forward(tu.randint(100, low=-10, high=10).to(torch.int32), tu.rand(100, low=-10, high=10))
# ==============================================================================
class ElementwiseFmodTensor_Int(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int32, True),
([-1], torch.int32, True),
])
def forward(self, x, y):
return torch.fmod(x, y)
@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int())
def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils):
module.forward(tu.randint(100, low=0, high=1000).to(torch.int32), tu.randint(100, low=1, high=1000).to(torch.int32))
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module):