mirror of https://github.com/llvm/torch-mlir
Fix TupleIndex canonicalizer.
It would change the result type.pull/775/head snapshot-20220503.429
parent
ab5ad7af09
commit
32159c4e54
|
@ -1388,6 +1388,13 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
if (i >= (int64_t)tupleConstruct.elements().size())
|
if (i >= (int64_t)tupleConstruct.elements().size())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Value element = tupleConstruct.elements()[i];
|
||||||
|
// TODO: We should have a clear picture of whether we want to consistently
|
||||||
|
// allow refinement, and where. It seems desirable to require precise
|
||||||
|
// type equality for TupleConstruct / TupleIndex, but that might break
|
||||||
|
// things.
|
||||||
|
if (element.getType() != op.getType())
|
||||||
|
return failure();
|
||||||
rewriter.replaceOp(op, tupleConstruct.elements()[i]);
|
rewriter.replaceOp(op, tupleConstruct.elements()[i]);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
|
@ -1055,6 +1055,19 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
|
||||||
return %1 : !torch.tensor
|
return %1 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.prim.TupleIndex$different_types$no_change(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[1,768],f32>) -> !torch.tensor {
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0]] : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
|
||||||
|
// CHECK: %[[ELEMENT:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INT0]] : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
|
||||||
|
// CHECK: return %[[ELEMENT]] : !torch.tensor
|
||||||
|
func @torch.prim.TupleIndex$different_types$no_change(%arg0: !torch.tensor<[1,768],f32>) -> !torch.tensor {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
|
||||||
|
%1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
|
||||||
|
return %1 : !torch.tensor
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.prim.unchecked_cast$derefine
|
// CHECK-LABEL: func @torch.prim.unchecked_cast$derefine
|
||||||
// CHECK-next: return %arg0 : !torch.list<int>
|
// CHECK-next: return %arg0 : !torch.list<int>
|
||||||
func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<int>) -> !torch.list<int> {
|
func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
|
Loading…
Reference in New Issue