mirror of https://github.com/llvm/torch-mlir
torch.prim.TupleIndex: Adjust tensor types when folding.
In cases where a refinement/derefinement was needed, we didn't fold. Fixes https://github.com/llvm/torch-mlir/issues/863pull/839/head
parent
2af53ce434
commit
3fb54cba4c
|
@ -1414,14 +1414,20 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
if (i >= (int64_t)tupleConstruct.elements().size())
|
||||
return failure();
|
||||
|
||||
Value element = tupleConstruct.elements()[i];
|
||||
// TODO: We should have a clear picture of whether we want to consistently
|
||||
// allow refinement, and where. It seems desirable to require precise
|
||||
// type equality for TupleConstruct / TupleIndex, but that might break
|
||||
// things.
|
||||
if (element.getType() != op.getType())
|
||||
return failure();
|
||||
rewriter.replaceOp(op, tupleConstruct.elements()[i]);
|
||||
Value replacement = tupleConstruct.elements()[i];
|
||||
if (replacement.getType() != op.getType()) {
|
||||
if (op.getType().isa<BaseTensorType>()) {
|
||||
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
||||
op.getLoc(), op.getType(), replacement);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, replacement);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1055,16 +1055,14 @@ func.func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.te
|
|||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.prim.TupleIndex$different_types$no_change(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[1,768],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0]] : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
|
||||
// CHECK: %[[ELEMENT:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INT0]] : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
|
||||
// CHECK: return %[[ELEMENT]] : !torch.tensor
|
||||
func.func @torch.prim.TupleIndex$different_types$no_change(%arg0: !torch.tensor<[1,768],f32>) -> !torch.tensor {
|
||||
// CHECK-LABEL: func.func @torch.prim.TupleIndex$adjust_type$tensor(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[7],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[RETURN:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor<[7],f32> to !torch.tensor
|
||||
// CHECK: return %[[RETURN]] : !torch.tensor
|
||||
func.func @torch.prim.TupleIndex$adjust_type$tensor(%arg0: !torch.tensor<[7],f32>) -> !torch.tensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[1,768],f32> -> !torch.tuple<tensor<[1,768],f32>>
|
||||
%1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[1,768],f32>>, !torch.int -> !torch.tensor
|
||||
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor<[7],f32> -> !torch.tuple<tensor<[7],f32>>
|
||||
%1 = torch.prim.TupleIndex %0, %int0 : !torch.tuple<tensor<[7],f32>>, !torch.int -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue