[torch] Fix lowerings of rshift and lshift (#3665)

I missed adding second operand conversion and adding them to the set of
rewrite patterns.
pull/3640/merge
Rob Suderman 2024-08-23 20:27:18 -07:00 committed by GitHub
parent 9a4c8c606c
commit b3b8e2e96a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 7 deletions

View File

@ -850,7 +850,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
cast<RankedTensorType>(converter->convertType(lshiftScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value other =
convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/operands[1].getType(),
/*dstOriginalDtype=*/dtype);
return b.create<arith::ShLIOp>(loc, self, other);
}
if (auto rshiftScalar = dyn_cast<Aten__Rshift__ScalarOp>(op)) {
@ -858,7 +861,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
cast<RankedTensorType>(converter->convertType(rshiftScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value other =
convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/operands[1].getType(),
/*dstOriginalDtype=*/dtype);
return b.create<arith::ShRUIOp>(loc, self, other);
}
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
@ -1610,7 +1616,8 @@ public:
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
@ -3304,10 +3311,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,