mirror of https://github.com/llvm/torch-mlir
[Torch] Support AtenDivTensorModeOp with static int input for linalg and stablehlo backend (#3088)
parent
d2432bbe5a
commit
ac1cd3d78a
|
@ -774,13 +774,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Type dtype = converter->convertType(divTensorMode.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
divTensorMode.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
Value div;
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
else {
|
||||
if (dtype.isUnsignedInteger())
|
||||
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||
else
|
||||
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
||||
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
|
||||
return div;
|
||||
|
@ -794,17 +798,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
Value pred =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, div, cstZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
div, cstZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||
} else
|
||||
return div;
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
return b.create<math::FloorOp>(loc, div);
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<math::FloorOp>(loc, div);
|
||||
else if (!dtype.isUnsignedInteger()) {
|
||||
Type defaultIntToFloatType = b.getF64Type();
|
||||
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
||||
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
||||
return convert;
|
||||
} else {
|
||||
return div;
|
||||
}
|
||||
}
|
||||
divTensorMode.emitError("invalid rounding mode");
|
||||
return nullptr;
|
||||
|
|
|
@ -445,7 +445,8 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant str rounding mode");
|
||||
|
||||
if (roundingMode == "trunc") {
|
||||
// if trunc and int, do nothing
|
||||
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
||||
|
@ -456,7 +457,20 @@ public:
|
|||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
if (outElemTy.isa<mlir::FloatType>())
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
else if (!outElemTy.isUnsignedInteger()) {
|
||||
TensorType defaultIntToFloatType =
|
||||
outType.cloneWith(outType.getShape(), rewriter.getF64Type());
|
||||
lhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
|
||||
rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
|
||||
result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
|
|
|
@ -245,6 +245,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||
"ElementwiseDivRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
|
@ -479,6 +483,10 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseFloorIntModule_basic",
|
||||
|
@ -2024,7 +2032,10 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseLogIntModule_basic",
|
||||
|
|
|
@ -2789,6 +2789,88 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
|||
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||
|
||||
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4], torch.float32, True),
|
||||
([4], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
|
||||
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32, True),
|
||||
([3, 4], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
|
||||
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||
|
||||
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int32, True),
|
||||
([3, 4], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
|
||||
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int32, True),
|
||||
([3, 4], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
|
||||
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue