diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 31f1a723f..4b2f80612 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1506,6 +1506,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return value;
   }
 
+  if (auto isClose = dyn_cast<AtenIscloseOp>(op)) {
+    double rtol, atol;
+    bool equalNan;
+    if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) {
+      isClose.emitError("rtol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) {
+      isClose.emitError("atol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) {
+      isClose.emitError("unimplemented: equal_nan is expected to be false");
+      return nullptr;
+    }
+    auto lhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[0].getType());
+    auto rhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[1].getType());
+    if (!lhsType || !rhsType) {
+      isClose.emitError("unimplemented: only FP element type is supported");
+      return nullptr;
+    }
+    // Choose the widest float type as compute type.
+    auto computeType =
+        lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType;
+    computeType = computeType.getWidth() >= 32 ? computeType : b.getF32Type();
+    auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType);
+    auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType);
+    // Reference to the definition of torch.isclose:
+    // ∣input − other∣ <= atol + rtol × ∣other∣
+    auto diff = b.create<arith::SubFOp>(loc, computeType, cvtArg0, cvtArg1);
+    auto absDiff = b.create<math::AbsFOp>(loc, computeType, diff);
+    auto cstRtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, rtol));
+    auto absOther = b.create<math::AbsFOp>(loc, computeType, cvtArg1);
+    auto mul = b.create<arith::MulFOp>(loc, computeType, cstRtol, absOther);
+    auto cstAtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, atol));
+    auto threshold = b.create<arith::AddFOp>(loc, computeType, cstAtol, mul);
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, absDiff,
+                                   threshold);
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
@@ -1564,7 +1606,7 @@ public:
         AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
         AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
         AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-        AtenQuantizePerTensorOp>(op))
+        AtenQuantizePerTensorOp, AtenIscloseOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3256,7 +3298,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
       AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
-      AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
+      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 25e7b98ca..044c8154f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -16,8 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
 print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
 
 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    "IscloseStaticModule_basic",
-    "IscloseStaticModuleTrue_basic",
     # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
     # these interpolate tests are added specifically to test onnx.Resize.
     "InterpolateDynamicModule_sizes_bilinear",