[TorchToLinalg] Support torch.isclose lower to linalg (#3631)

pull/3657/head
lingzhiz1998 2024-08-21 11:55:54 +08:00 committed by GitHub
parent a24114efa3
commit 7f886cc270
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 4 deletions

View File

@ -1506,6 +1506,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return value; return value;
} }
if (auto isClose = dyn_cast<AtenIscloseOp>(op)) {
double rtol, atol;
bool equalNan;
if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) {
isClose.emitError("rtol must be a scalar constant");
return nullptr;
}
if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) {
isClose.emitError("atol must be a scalar constant");
return nullptr;
}
if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) {
isClose.emitError("unimplemented: equal_nan is expected to be false");
return nullptr;
}
auto lhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[0].getType());
auto rhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[1].getType());
if (!lhsType || !rhsType) {
isClose.emitError("unimplemented: only FP element type is supported");
return nullptr;
}
// Choose the widest float type as compute type.
auto computeType =
lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType;
computeType = computeType.getWidth() >= 32 ? computeType : b.getF32Type();
auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType);
auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType);
// Reference to the definition of torch.isclose:
// input other <= atol + rtol × other
auto diff = b.create<arith::SubFOp>(loc, computeType, cvtArg0, cvtArg1);
auto absDiff = b.create<math::AbsFOp>(loc, computeType, diff);
auto cstRtol =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, rtol));
auto absOther = b.create<math::AbsFOp>(loc, computeType, cvtArg1);
auto mul = b.create<arith::MulFOp>(loc, computeType, cstRtol, absOther);
auto cstAtol =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, atol));
auto threshold = b.create<arith::AddFOp>(loc, computeType, cstAtol, mul);
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, absDiff,
threshold);
}
op->emitError("unimplemented lowering in " op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp"); "createLinalgPayloadCalculationForElementwiseOp");
return nullptr; return nullptr;
@ -1564,7 +1606,7 @@ public:
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>(op)) AtenQuantizePerTensorOp, AtenIscloseOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -3256,7 +3298,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp, AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context); patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -16,8 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison()) print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
# these interpolate tests are added specifically to test onnx.Resize. # these interpolate tests are added specifically to test onnx.Resize.
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",