mirror of https://github.com/llvm/torch-mlir
lower torch.aten.isinf to linalg (#2638)
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>pull/2708/head snapshot-20231229.1067
parent
9fc212ea9a
commit
6660a26594
|
@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (isa<AtenAbsOp>(op))
|
||||
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenIsinfOp>(op)){
|
||||
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
|
||||
Value infinity = b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
|
||||
return createEqual(b, loc, abs.getType(), abs, infinity);
|
||||
}
|
||||
if (isa<AtenSigmoidOp>(op)) {
|
||||
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
|
@ -1343,7 +1349,7 @@ public:
|
|||
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
AtenTrilOp,
|
||||
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
|
|
|
@ -39,6 +39,10 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
|
|||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
|
||||
return {Shape(at::kBool, self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
|
||||
const at::Tensor& self, at::IntArrayRef kernel_size,
|
||||
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
|
||||
|
|
|
@ -1033,6 +1033,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"ElementwiseAtenDivIntScalarModule_basic",
|
||||
"ElementwiseAtenIsinfOpModule_basic",
|
||||
"ElementwiseAtenWhereSelfModule_basic",
|
||||
"ElementwiseBinaryModule_basic",
|
||||
"ElementwiseBinaryStaticShapeModule_basic",
|
||||
|
@ -1328,6 +1329,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
"SliceWholeTensorModule_basic",
|
||||
"TensorFloatModule_basic",
|
||||
"TensorIntModule_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
}) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
||||
|
@ -1489,5 +1492,4 @@ LTC_XFAIL_SET = {
|
|||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||
"ElementwiseIsinfModule_basic",
|
||||
}
|
||||
|
|
|
@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.randint(4, 5, high=2).bool())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseAtenIsinfOpModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.isinf(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule())
|
||||
def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
|
||||
test_input = torch.tensor(
|
||||
[
|
||||
[1, float('inf'), 2, float('-inf'), float('nan')],
|
||||
[1, float('inf'), float('-inf'), float('nan'), 3],
|
||||
]
|
||||
)
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue