mirror of https://github.com/llvm/torch-mlir
Add basic support for list of optional tensors in reduce-op-variants (#971)
This commit adds support for lists of type `list<optional<tensor>>` where each element in the list is either a `!torch.tensor` or a `!torch.none`.pull/1028/head
parent
f202ae0012
commit
6a72ab4502
|
@ -36,6 +36,22 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
|
|||
overwrittenTensor);
|
||||
}
|
||||
|
||||
static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
|
||||
if (auto optionalType = type.dyn_cast<OptionalType>()) {
|
||||
Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
|
||||
optionalType.getContainedType());
|
||||
return OptionalType::get(newContainedType);
|
||||
} else if (auto listType = type.dyn_cast<ListType>()) {
|
||||
Type newContainedType =
|
||||
getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
|
||||
return ListType::get(newContainedType);
|
||||
} else if (auto tensorType = type.dyn_cast<NonValueTensorType>()) {
|
||||
return tensorType.getWithValueSemantics();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Convert value semantic ops operating on mutable arrays to instead operate on
|
||||
// immutable tensors.
|
||||
|
@ -76,23 +92,35 @@ public:
|
|||
continue;
|
||||
|
||||
// TODO: Handle optional type in list type.
|
||||
if (listType.getContainedType().isa<OptionalType>()) {
|
||||
if (auto optionalType =
|
||||
listType.getContainedType().dyn_cast<OptionalType>()) {
|
||||
if (!llvm::all_of(listConstruct.elements(), [](Value val) {
|
||||
return val.getType().isa<NonValueTensorType>();
|
||||
}))
|
||||
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
||||
})) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: list containing optional type is not "
|
||||
"handled.");
|
||||
}
|
||||
}
|
||||
|
||||
auto newListElements = llvm::to_vector<4>(llvm::map_range(
|
||||
auto newListElements = llvm::to_vector(llvm::map_range(
|
||||
listConstruct.elements(), [&](Value tensor) -> Value {
|
||||
return rewriter.create<CopyToValueTensorOp>(op->getLoc(), tensor);
|
||||
if (tensor.getType().isa<NonValueTensorType>()) {
|
||||
return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||
tensor);
|
||||
}
|
||||
return tensor;
|
||||
}));
|
||||
|
||||
Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
|
||||
if (!newListType) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unable to convert list type to value semantics.");
|
||||
}
|
||||
opOperand.set(rewriter.create<PrimListConstructOp>(
|
||||
op->getLoc(),
|
||||
Torch::ListType::get(newListElements.front().getType()),
|
||||
newListElements));
|
||||
op->getLoc(), newListType, newListElements));
|
||||
} else if (auto optionalType = operandType.dyn_cast<OptionalType>()) {
|
||||
// TODO: A more general way to handle the optional type is to
|
||||
// introduce a `copy.to_optional_vtensor` op.
|
||||
|
|
|
@ -269,6 +269,22 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isValidSubtype(op.x().getType(), op.result().getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor type is not a valid subtype of result type");
|
||||
}
|
||||
rewriter.replaceOp(op, op.x());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
|
||||
PatternRewriter &rewriter,
|
||||
bool &madeChange) {
|
||||
|
@ -396,6 +412,7 @@ class SimplifyShapeCalculationsPass
|
|||
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
|
||||
patterns.insert<DecomposeAtenSizeOp>(context);
|
||||
patterns.insert<RefineShapeCalculateOp>(context);
|
||||
patterns.insert<FoldPrimUncheckedCastOp>(context);
|
||||
|
||||
PrimIfOp::getCanonicalizationPatterns(patterns, context);
|
||||
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
|
||||
|
|
|
@ -120,8 +120,8 @@ func.func @torch.tensor.literal() -> !torch.tensor {
|
|||
// CHECK-SAME: (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
|
||||
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
|
||||
// CHECK: %[[INDICES_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDICES]] : !torch.vtensor<[2,3],si64>
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<vtensor<[2,3],si64>>
|
||||
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<optional<vtensor<[2,3],si64>>>
|
||||
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<optional<vtensor<[2,3],si64>>> -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: return %[[RET]] : !torch.tensor
|
||||
func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
||||
|
@ -130,6 +130,27 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
||||
// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
|
||||
// CHECK-SAME: %[[INDICES_1:.*]]: !torch.tensor<[3],si64>) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[NONE]], %[[INDICES_0]], %[[INDICES_1]] :
|
||||
// CHECK-SAME: (!torch.none, !torch.tensor<[2,3],si64>, !torch.tensor<[3],si64>) -> !torch.list<optional<tensor>>
|
||||
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
|
||||
// CHECK: %[[INDICES_VTENSOR_0:.*]] = torch.copy.to_vtensor %[[INDICES_0]] : !torch.vtensor<[2,3],si64>
|
||||
// CHECK: %[[INDICES_VTENSOR_1:.*]] = torch.copy.to_vtensor %[[INDICES_1]] : !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[NONE]], %[[INDICES_VTENSOR_0]], %[[INDICES_VTENSOR_1]] : (!torch.none, !torch.vtensor<[2,3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
|
||||
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: return %[[RET]] : !torch.tensor
|
||||
func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%self: !torch.tensor<[5],f32>, %indices0: !torch.tensor<[2,3],si64>, %indices1: !torch.tensor<[3],si64>) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%tensor_optional_list = torch.prim.ListConstruct %none, %indices0, %indices1 : (!torch.none, !torch.tensor<[2,3],si64>, !torch.tensor<[3],si64>) -> !torch.list<optional<tensor>>
|
||||
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor>> -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.uniform_(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
|
||||
|
@ -184,9 +205,9 @@ func.func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
|
|||
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.tensor) -> !torch.list<optional<tensor>>
|
||||
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor
|
||||
// CHECK: %[[INDEX_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDEX]] : !torch.vtensor
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<vtensor>
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<optional<vtensor>>
|
||||
// CHECK: %[[VALUES_VTENSOR:.*]] = torch.copy.to_vtensor %[[VALUES]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<vtensor>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<optional<vtensor>>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor
|
||||
|
|
|
@ -406,3 +406,34 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor
|
|||
} : !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @fold_prim_unchecked_cast_op(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
|
||||
// CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk>
|
||||
// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk>
|
||||
// CHECK: } shapes {
|
||||
// CHECK: %[[VAL_6:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_2]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<int>
|
||||
// CHECK: } : !torch.vtensor<[?,?],unk>
|
||||
// CHECK: %[[VAL_9:.*]] = torch.tensor_static_info_cast %[[VAL_10:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
|
||||
// CHECK: return %[[VAL_9]] : !torch.vtensor
|
||||
// CHECK: }
|
||||
func.func @fold_prim_unchecked_cast_op(%arg0: !torch.vtensor, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%tensor_list = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[?,?],si64>) -> !torch.list<optional<vtensor>>
|
||||
%0 = torch.shape.calculate {
|
||||
torch.shape.calculate.yield %arg0 : !torch.vtensor
|
||||
} shapes {
|
||||
%getitem = torch.aten.__getitem__.t %tensor_list, %int0 : !torch.list<optional<vtensor>>, !torch.int -> !torch.optional<vtensor>
|
||||
%unchecked_cast = torch.prim.unchecked_cast %getitem : !torch.optional<vtensor> -> !torch.vtensor
|
||||
%size = torch.aten.size %unchecked_cast : !torch.vtensor -> !torch.list<int>
|
||||
torch.shape.calculate.yield.shapes %size : !torch.list<int>
|
||||
} : !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue