[Torch] Support AtenDivTensorModeOp with static int input for linalg and stablehlo backend (#3088)

pull/3097/head
Xinyu Yang 2024-04-02 17:28:53 +08:00 committed by GitHub
parent d2432bbe5a
commit ac1cd3d78a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 140 additions and 14 deletions

View File

@ -774,13 +774,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(divTensorMode.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
divTensorMode.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value div = b.create<arith::DivFOp>(loc, lhs, rhs);
Value div;
if (dtype.isa<mlir::FloatType>())
div = b.create<arith::DivFOp>(loc, lhs, rhs);
else {
if (dtype.isUnsignedInteger())
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
else
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
}
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
return div;
@ -794,17 +798,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (dtype.isa<mlir::FloatType>()) {
Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, div, cstZero);
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
div, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
} else
return div;
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (dtype.isa<mlir::FloatType>())
return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
div = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, div);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
} else {
return div;
}
}
divTensorMode.emitError("invalid rounding mode");
return nullptr;

View File

@ -445,7 +445,8 @@ public:
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");
if (roundingMode == "trunc") {
// if trunc and int, do nothing
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
@ -456,7 +457,20 @@ public:
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (outElemTy.isa<mlir::FloatType>())
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
else if (!outElemTy.isUnsignedInteger()) {
TensorType defaultIntToFloatType =
outType.cloneWith(outType.getShape(), rewriter.getF64Type());
lhs =
hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
rhs =
hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
bcastDimensions);
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
}
}
rewriter.replaceOp(op, result);
return success();

View File

@ -245,6 +245,10 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseDivRoundingModeFloorModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
@ -479,6 +483,10 @@ STABLEHLO_PASS_SET = {
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic",
@ -2024,7 +2032,10 @@ ONNX_XFAIL_SET = {
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic",

View File

@ -2789,6 +2789,88 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4], torch.float32, True),
([4], torch.float64, True),
])
def forward(self, a, b):
return torch.div(a, b, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.float32, True),
([3, 4], torch.float64, True),
])
def forward(self, a, b):
return torch.div(a, b, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
([3, 4], torch.int64, True),
])
def forward(self, a, b):
return torch.div(a, b, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
([3, 4], torch.int64, True),
])
def forward(self, a, b):
return torch.div(a, b, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
# ==============================================================================