mirror of https://github.com/llvm/torch-mlir
[Torch] Support AtenDivTensorModeOp with static int input for linalg and stablehlo backend (#3088)
parent
d2432bbe5a
commit
ac1cd3d78a
|
@ -774,13 +774,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Type dtype = converter->convertType(divTensorMode.getType())
|
Type dtype = converter->convertType(divTensorMode.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!dtype.isa<mlir::FloatType>()) {
|
|
||||||
divTensorMode.emitError("unimplemented: non-floating point dtype");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
Value div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
Value div;
|
||||||
|
if (dtype.isa<mlir::FloatType>())
|
||||||
|
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
|
else {
|
||||||
|
if (dtype.isUnsignedInteger())
|
||||||
|
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||||
|
else
|
||||||
|
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
|
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
|
||||||
return div;
|
return div;
|
||||||
|
@ -794,17 +798,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
if (roundingMode == "trunc") {
|
if (roundingMode == "trunc") {
|
||||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||||
// to C-style integer division.
|
// to C-style integer division.
|
||||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
Value floor = b.create<math::FloorOp>(loc, div);
|
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
Value floor = b.create<math::FloorOp>(loc, div);
|
||||||
Value pred =
|
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, div, cstZero);
|
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
div, cstZero);
|
||||||
|
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||||
|
} else
|
||||||
|
return div;
|
||||||
}
|
}
|
||||||
if (roundingMode == "floor") {
|
if (roundingMode == "floor") {
|
||||||
// "floor" - rounds the results of the division down. Equivalent to
|
// "floor" - rounds the results of the division down. Equivalent to
|
||||||
// floor division in Python (the // operator)
|
// floor division in Python (the // operator)
|
||||||
return b.create<math::FloorOp>(loc, div);
|
if (dtype.isa<mlir::FloatType>())
|
||||||
|
return b.create<math::FloorOp>(loc, div);
|
||||||
|
else if (!dtype.isUnsignedInteger()) {
|
||||||
|
Type defaultIntToFloatType = b.getF64Type();
|
||||||
|
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
||||||
|
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
||||||
|
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
|
Value floor = b.create<math::FloorOp>(loc, div);
|
||||||
|
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
||||||
|
return convert;
|
||||||
|
} else {
|
||||||
|
return div;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
divTensorMode.emitError("invalid rounding mode");
|
divTensorMode.emitError("invalid rounding mode");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -445,7 +445,8 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support constant str rounding mode");
|
op, "only support constant str rounding mode");
|
||||||
|
|
||||||
if (roundingMode == "trunc") {
|
// if trunc and int, do nothing
|
||||||
|
if (roundingMode == "trunc" && outElemTy.isa<mlir::FloatType>()) {
|
||||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||||
// to C-style integer division.
|
// to C-style integer division.
|
||||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
||||||
|
@ -456,7 +457,20 @@ public:
|
||||||
if (roundingMode == "floor") {
|
if (roundingMode == "floor") {
|
||||||
// "floor" - rounds the results of the division down. Equivalent to
|
// "floor" - rounds the results of the division down. Equivalent to
|
||||||
// floor division in Python (the // operator)
|
// floor division in Python (the // operator)
|
||||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
if (outElemTy.isa<mlir::FloatType>())
|
||||||
|
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||||
|
else if (!outElemTy.isUnsignedInteger()) {
|
||||||
|
TensorType defaultIntToFloatType =
|
||||||
|
outType.cloneWith(outType.getShape(), rewriter.getF64Type());
|
||||||
|
lhs =
|
||||||
|
hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
|
||||||
|
rhs =
|
||||||
|
hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
|
||||||
|
result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
|
||||||
|
bcastDimensions);
|
||||||
|
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||||
|
result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOp(op, result);
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -245,6 +245,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||||
"ElementwiseDivRoundingModeFloorModule_basic",
|
"ElementwiseDivRoundingModeFloorModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
|
@ -479,6 +483,10 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseCloneContiguousModule_basic",
|
"ElementwiseCloneContiguousModule_basic",
|
||||||
"ElementwiseCloneModule_basic",
|
"ElementwiseCloneModule_basic",
|
||||||
"ElementwiseCosModule_basic",
|
"ElementwiseCosModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseErfModule_basic",
|
"ElementwiseErfModule_basic",
|
||||||
"ElementwiseExpModule_basic",
|
"ElementwiseExpModule_basic",
|
||||||
"ElementwiseFloorIntModule_basic",
|
"ElementwiseFloorIntModule_basic",
|
||||||
|
@ -2024,7 +2032,10 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
"ElementwiseCosIntModule_basic",
|
"ElementwiseCosIntModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseExpIntModule_basic",
|
"ElementwiseExpIntModule_basic",
|
||||||
"ElementwiseLogIntModule_basic",
|
"ElementwiseLogIntModule_basic",
|
||||||
|
|
|
@ -2789,6 +2789,88 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
||||||
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||||
|
|
||||||
|
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([4], torch.float32, True),
|
||||||
|
([4], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.div(a, b, rounding_mode="trunc")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
|
||||||
|
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.float32, True),
|
||||||
|
([3, 4], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.div(a, b, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
|
||||||
|
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||||
|
|
||||||
|
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.int32, True),
|
||||||
|
([3, 4], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.div(a, b, rounding_mode="trunc")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
|
||||||
|
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.int32, True),
|
||||||
|
([3, 4], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.div(a, b, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
|
||||||
|
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue