torch.prim.TupleIndex: Adjust tensor types when folding.

Previously, in cases where a refinement or derefinement of the element's tensor type was needed, we didn't fold; now the fold inserts a torch.tensor_static_info_cast to adjust the type.

Fixes https://github.com/llvm/torch-mlir/issues/863
Sean Silva 2022-05-19 13:12:58 +00:00
parent 2af53ce434
commit 3fb54cba4c
2 changed files with 17 additions and 13 deletions
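As an illustration of the new behavior (shapes taken from the updated test case below), the canonicalization now folds

    %int0 = torch.constant.int 0
    %0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[7],f32> -> !torch.tuple<tensor<[7],f32>>
    %1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[7],f32>>, !torch.int -> !torch.tensor

into a cast that adjusts the element's static type back to the op's result type:

    %1 = torch.tensor_static_info_cast %arg0 : !torch.tensor<[7],f32> to !torch.tensor

Mismatches on non-tensor types are still not folded.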


@@ -1414,14 +1414,20 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,

     if (i >= (int64_t)tupleConstruct.elements().size())
       return failure();
-    Value element = tupleConstruct.elements()[i];
     // TODO: We should have a clear picture of whether we want to consistently
     // allow refinement, and where. It seems desirable to require precise
     // type equality for TupleConstruct / TupleIndex, but that might break
     // things.
-    if (element.getType() != op.getType())
-      return failure();
-    rewriter.replaceOp(op, tupleConstruct.elements()[i]);
+    Value replacement = tupleConstruct.elements()[i];
+    if (replacement.getType() != op.getType()) {
+      if (op.getType().isa<BaseTensorType>()) {
+        replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
+            op.getLoc(), op.getType(), replacement);
+      } else {
+        return failure();
+      }
+    }
+    rewriter.replaceOp(op, replacement);
     return success();
   });
 }


@@ -1055,16 +1055,14 @@ func.func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.te
   return %1 : !torch.tensor
 }

-// CHECK-LABEL: func.func @torch.prim.TupleIndex$different_types$no_change(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[1,768],f32>) -> !torch.tensor {
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0]] : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
-// CHECK: %[[ELEMENT:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INT0]] : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
-// CHECK: return %[[ELEMENT]] : !torch.tensor
-func.func @torch.prim.TupleIndex$different_types$no_change(%arg0: !torch.tensor<[1,768],f32>) -> !torch.tensor {
+// CHECK-LABEL: func.func @torch.prim.TupleIndex$adjust_type$tensor(
+// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[7],f32>) -> !torch.tensor {
+// CHECK: %[[RETURN:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor<[7],f32> to !torch.tensor
+// CHECK: return %[[RETURN]] : !torch.tensor
+func.func @torch.prim.TupleIndex$adjust_type$tensor(%arg0: !torch.tensor<[7],f32>) -> !torch.tensor {
   %int0 = torch.constant.int 0
-  %0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
-  %1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
+  %0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[7],f32> -> !torch.tuple<tensor<[7],f32>>
+  %1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[7],f32>>, !torch.int -> !torch.tensor
   return %1 : !torch.tensor
 }