Fix TupleIndex canonicalizer.

It would change the result type.
pull/775/head snapshot-20220503.429
Sean Silva 2022-05-03 09:12:09 +00:00
parent ab5ad7af09
commit 32159c4e54
2 changed files with 20 additions and 0 deletions

View File

@ -1388,6 +1388,13 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
if (i >= (int64_t)tupleConstruct.elements().size())
return failure();
Value element = tupleConstruct.elements()[i];
// TODO: We should have a clear picture of whether we want to consistently
// allow refinement, and where. It seems desirable to require precise
// type equality for TupleConstruct / TupleIndex, but that might break
// things.
if (element.getType() != op.getType())
return failure();
rewriter.replaceOp(op, tupleConstruct.elements()[i]);
return success();
});

View File

@ -1055,6 +1055,19 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
return %1 : !torch.tensor
}
// CHECK-LABEL: func @torch.prim.TupleIndex$different_types$no_change(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[1,768],f32>) -> !torch.tensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0]] : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
// CHECK: %[[ELEMENT:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INT0]] : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
// CHECK: return %[[ELEMENT]] : !torch.tensor
func @torch.prim.TupleIndex$different_types$no_change(%arg0: !torch.tensor<[1,768],f32>) -> !torch.tensor {
%int0 = torch.constant.int 0
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
%1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
return %1 : !torch.tensor
}
// CHECK-LABEL: func @torch.prim.unchecked_cast$derefine
// CHECK-next: return %arg0 : !torch.list<int>
func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<int>) -> !torch.list<int> {