mirror of https://github.com/llvm/torch-mlir
Delete unnecessary linalg conversion for aten.fmod (#3707)
Follow-up cleanup for [this PR](https://github.com/llvm/torch-mlir/pull/3689), which introduced a decomposition for `aten.fmod.Tensor`. This means that the lowering for this operator in linalg is no longer needed. Thanks to @vivekkhandelwal1 for pointing this out. --------- Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com> (pull/3713/head)
parent
7b94ced39a
commit
bc70c50373
|
@ -1282,29 +1282,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
|
||||
operands);
|
||||
}
|
||||
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
|
||||
Type newResultType =
|
||||
cast<RankedTensorType>(converter->convertType(fmod.getType()))
|
||||
.getElementType();
|
||||
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
||||
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
||||
Value result;
|
||||
|
||||
if (isa<mlir::FloatType>(newResultType)) {
|
||||
Value n = b.create<arith::DivFOp>(loc, self, other);
|
||||
n = b.create<math::TruncOp>(loc, n);
|
||||
Value n_y = b.create<arith::MulFOp>(loc, n, other);
|
||||
result = b.create<arith::SubFOp>(loc, self, n_y);
|
||||
} else if (isa<mlir::IntegerType>(newResultType)) {
|
||||
Value n = b.create<arith::DivSIOp>(loc, self, other);
|
||||
Value n_y = b.create<arith::MulIOp>(loc, n, other);
|
||||
result = b.create<arith::SubIOp>(loc, self, n_y);
|
||||
} else {
|
||||
fmod.emitError("Unsupported type encountered for AtenFmodTensorOp.");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
|
||||
|
@ -1612,23 +1589,23 @@ public:
|
|||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
|
||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
||||
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
||||
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
|
||||
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
|
||||
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
|
||||
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
||||
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
|
||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
|
||||
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
|
||||
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
|
@ -3385,10 +3362,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
|
||||
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
|
||||
AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
|
||||
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp, AtenIscloseOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
Loading…
Reference in New Issue