mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Support torch.isclose lower to linalg (#3631)
parent a24114efa3
commit 7f886cc270
@@ -1506,6 +1506,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return value;
   }
 
+  if (auto isClose = dyn_cast<AtenIscloseOp>(op)) {
+    double rtol, atol;
+    bool equalNan;
+    if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) {
+      isClose.emitError("rtol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) {
+      isClose.emitError("atol must be a scalar constant");
+      return nullptr;
+    }
+    if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) {
+      isClose.emitError("unimplemented: equal_nan is expected to be false");
+      return nullptr;
+    }
+    auto lhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[0].getType());
+    auto rhsType = mlir::dyn_cast<mlir::FloatType>(payloadArgs[1].getType());
+    if (!lhsType || !rhsType) {
+      isClose.emitError("unimplemented: only FP element type is supported");
+      return nullptr;
+    }
+    // Choose the widest float type as compute type.
+    auto computeType =
+        lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType;
+    computeType = computeType.getWidth() >= 32 ? computeType : b.getF32Type();
+    auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType);
+    auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType);
+    // Reference to the definition of torch.isclose:
+    // ∣input − other∣ <= atol + rtol × ∣other∣
+    auto diff = b.create<arith::SubFOp>(loc, computeType, cvtArg0, cvtArg1);
+    auto absDiff = b.create<math::AbsFOp>(loc, computeType, diff);
+    auto cstRtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, rtol));
+    auto absOther = b.create<math::AbsFOp>(loc, computeType, cvtArg1);
+    auto mul = b.create<arith::MulFOp>(loc, computeType, cstRtol, absOther);
+    auto cstAtol =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(computeType, atol));
+    auto threshold = b.create<arith::AddFOp>(loc, computeType, cstAtol, mul);
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, absDiff,
+                                   threshold);
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
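The added payload implements the reference definition of torch.isclose, with equal_nan left unsupported: an element pair is close when ∣input − other∣ <= atol + rtol × ∣other∣. A minimal Python sketch of that check against torch.isclose itself (the helper name and sample values are illustrative, not part of the patch):

import torch

# Reference formula the linalg payload computes elementwise:
#   |input - other| <= atol + rtol * |other|
def isclose_ref(input, other, rtol=1e-05, atol=1e-08):
    return (input - other).abs() <= atol + rtol * other.abs()

x = torch.tensor([1.0, 1.0 + 1e-9, 2.0])
y = torch.tensor([1.0, 1.0, 2.5])
assert torch.equal(isclose_ref(x, y), torch.isclose(x, y))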
@@ -1564,7 +1606,7 @@ public:
           AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
           AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
           AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-          AtenQuantizePerTensorOp>(op))
+          AtenQuantizePerTensorOp, AtenIscloseOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3256,7 +3298,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
       AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
-      AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
+      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
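With AtenIscloseOp marked illegal and routed through ConvertElementwiseOp, the op lowers along the standard elementwise path to a linalg.generic. A hedged sketch of exercising the pipeline from Python via the fx importer (the exact API shape is an assumption based on the importer of that era, not part of the patch):

import torch
from torch_mlir import fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.isclose(x, y)

# Lower through the torch backend pipeline down to linalg-on-tensors.
m = fx.export_and_import(M(), torch.ones(3), torch.ones(3),
                         output_type="linalg-on-tensors")
print(m)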
@@ -16,8 +16,6 @@ from torch_mlir._version import torch_version_for_comparison, version
 print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
 
 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    "IscloseStaticModule_basic",
-    "IscloseStaticModuleTrue_basic",
     # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
     # these interpolate tests are added specifically to test onnx.Resize.
     "InterpolateDynamicModule_sizes_bilinear",
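The two isclose e2e tests are dropped from LINALG_XFAIL_SET now that the lowering exists. For orientation, a sketch of what such a test module looks like in the torch-mlir e2e suite (the shapes and module body are assumptions, not the registered test):

import torch
from torch_mlir_e2e_test.annotations import annotate_args, export

class IscloseStaticModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([5, 5], torch.float32, True),
        ([5, 5], torch.float32, True),
    ])
    def forward(self, x, y):
        # Defaults: rtol=1e-05, atol=1e-08, equal_nan=False.
        return torch.isclose(x, y)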