mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Support torch.isclose lower to linalg (#3631)
parent
a24114efa3
commit
7f886cc270
|
@ -1506,6 +1506,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return value;
|
||||
}
|
||||
|
||||
if (auto isClose = dyn_cast<AtenIscloseOp>(op)) {
|
||||
double rtol, atol;
|
||||
bool equalNan;
|
||||
if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) {
|
||||
isClose.emitError("rtol must be a scalar constant");
|
||||
return nullptr;
|
||||
}
|
||||
if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) {
|
||||
isClose.emitError("atol must be a scalar constant");
|
||||
return nullptr;
|
||||
}
|
||||
if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) {
|
||||
isClose.emitError("unimplemented: equal_nan is expected to be false");
|
||||
return nullptr;
|
||||
}
|
||||
auto lhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[0].getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[1].getType());
|
||||
if (!lhsType || !rhsType) {
|
||||
isClose.emitError("unimplemented: only FP element type is supported");
|
||||
return nullptr;
|
||||
}
|
||||
// Choose the widest float type as compute type.
|
||||
auto computeType =
|
||||
lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType;
|
||||
computeType = computeType.getWidth() >= 32 ? computeType : b.getF32Type();
|
||||
auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType);
|
||||
auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType);
|
||||
// Reference to the definition of torch.isclose:
|
||||
// ∣input − other∣ <= atol + rtol × ∣other∣
|
||||
auto diff = b.create<arith::SubFOp>(loc, computeType, cvtArg0, cvtArg1);
|
||||
auto absDiff = b.create<math::AbsFOp>(loc, computeType, diff);
|
||||
auto cstRtol =
|
||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, rtol));
|
||||
auto absOther = b.create<math::AbsFOp>(loc, computeType, cvtArg1);
|
||||
auto mul = b.create<arith::MulFOp>(loc, computeType, cstRtol, absOther);
|
||||
auto cstAtol =
|
||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, atol));
|
||||
auto threshold = b.create<arith::AddFOp>(loc, computeType, cstAtol, mul);
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, absDiff,
|
||||
threshold);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
return nullptr;
|
||||
|
@ -1564,7 +1606,7 @@ public:
|
|||
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
|
||||
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp>(op))
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -3256,7 +3298,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
|
||||
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
|
||||
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
|
||||
AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -16,8 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
|
|||
print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
|
||||
|
||||
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
|
||||
# these interpolate tests are added specifically to test onnx.Resize.
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
|
|
Loading…
Reference in New Issue