mirror of https://github.com/llvm/torch-mlir
[torch] Fix lowerings of rshift and lshift (#3665)
I missed adding second operand conversion and adding them to the set of rewrite patterns.pull/3640/merge
parent
9a4c8c606c
commit
b3b8e2e96a
|
@ -850,7 +850,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
cast<RankedTensorType>(converter->convertType(lshiftScalar.getType()))
|
||||
.getElementType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
Value other =
|
||||
convertScalarToDtype(b, loc, operands[1], dtype,
|
||||
/*srcOriginalDtype=*/operands[1].getType(),
|
||||
/*dstOriginalDtype=*/dtype);
|
||||
return b.create<arith::ShLIOp>(loc, self, other);
|
||||
}
|
||||
if (auto rshiftScalar = dyn_cast<Aten__Rshift__ScalarOp>(op)) {
|
||||
|
@ -858,7 +861,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
cast<RankedTensorType>(converter->convertType(rshiftScalar.getType()))
|
||||
.getElementType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
Value other =
|
||||
convertScalarToDtype(b, loc, operands[1], dtype,
|
||||
/*srcOriginalDtype=*/operands[1].getType(),
|
||||
/*dstOriginalDtype=*/dtype);
|
||||
return b.create<arith::ShRUIOp>(loc, self, other);
|
||||
}
|
||||
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
|
||||
|
@ -1610,7 +1616,8 @@ public:
|
|||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
||||
|
@ -3304,10 +3311,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
|
||||
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
|
||||
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
|
||||
AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
|
|
Loading…
Reference in New Issue