[Stablehlo] Fix CumsumInputDtypeInt32Module_basic on the stablehlo backend. (#2797)
Code used for testing. For the location of `CumsumInputDtypeInt32Module` in the repo, see
[here](311b6b0286/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py (L4148)).
```python
import torch
import torch_mlir

class CumsumInputDtypeInt32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, val):
        return torch.ops.aten.cumsum(val, 1)

module = torch_mlir.compile(CumsumInputDtypeInt32Module(), [torch.randn(2, 7, 4).to(torch.int32)], output_type="stablehlo")
print(module.operation.get_asm())
```
After fixing the bug, the generated StableHLO is:
```mlir
module attributes {torch.debug_module_name = "CumsumInputDtypeInt32Module"} {
  func.func @forward(%arg0: tensor<2x7x4xi32>) -> tensor<2x7x4xi64> {
    %0 = stablehlo.constant dense<0> : tensor<i64>
    %1 = stablehlo.convert %arg0 : (tensor<2x7x4xi32>) -> tensor<2x7x4xi64>
    %2 = "stablehlo.reduce_window"(%1, %0) ({
    ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
      %3 = stablehlo.add %arg1, %arg2 : tensor<i64>
      stablehlo.return %3 : tensor<i64>
    }) {padding = dense<[[0, 0], [6, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 7, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<2x7x4xi64>, tensor<i64>) -> tensor<2x7x4xi64>
    return %2 : tensor<2x7x4xi64>
  }
}
```
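
The cumulative sum over dim 1 is expressed above as a `reduce_window`: the input is left-padded along dim 1 with `dim_size - 1` zeros and summed over a sliding window of size `dim_size`. A minimal numeric sketch of that equivalence (illustrative only, not part of the patch; it just mirrors the padding and window attributes shown in the IR):

```python
import torch
import torch.nn.functional as F

x = torch.arange(2 * 7 * 4, dtype=torch.int64).reshape(2, 7, 4)

# padding = [[0, 0], [6, 0], [0, 0]]: left-pad dim 1 with dim_size - 1 = 6 zeros.
padded = F.pad(x, (0, 0, 6, 0))  # F.pad lists pads from the last dim inward.

# window_dimensions = [1, 7, 1], window_strides = [1, 1, 1]: a size-7 sliding
# window along dim 1; summing each window gives the prefix sums.
windows = padded.unfold(1, 7, 1)   # shape (2, 7, 4, 7)
window_sums = windows.sum(-1)      # shape (2, 7, 4)

assert torch.equal(window_sums, torch.cumsum(x, dim=1))
```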
Branch: pull/2803/head · Tag: snapshot-20240125.1094 · Parent: f6f890520b · Commit: e581b33f96
```diff
@@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Value input = adaptor.getSelf();
   auto inputTy = input.getType().cast<RankedTensorType>();
+  auto outTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  inputTy = input.getType().cast<RankedTensorType>();
   auto inputElemTy = inputTy.getElementType();
   auto inputRank = inputTy.getRank();
   auto inputShape = inputTy.getShape();
-  auto outTy =
-      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
 
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
```
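
The converted result type already reflects PyTorch's dtype promotion for `cumsum` (sub-64-bit integer inputs produce an int64 result, hence the `tensor<2x7x4xi64>` above), so the change promotes the input to that type before the element type, rank, and shape are read from it. A small sketch of the promotion behavior on the PyTorch side (illustrative only, not part of the patch):

```python
import torch

# cumsum on an int32 tensor returns int64 by default, matching the i32 -> i64
# conversion seen in the lowered StableHLO.
x = torch.ones(2, 7, 4, dtype=torch.int32)
print(torch.ops.aten.cumsum(x, 1).dtype)  # torch.int64
```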