Add basic support for list of optional tensors in reduce-op-variants (#971)

This commit adds support for lists of type `list<optional<tensor>>`
where each element in the list is either a `!torch.tensor` or a
`!torch.none`.
Ramiro Leal-Cavazos 2022-07-08 13:12:15 -05:00 committed by GitHub
parent f202ae0012
commit 6a72ab4502
4 changed files with 109 additions and 12 deletions
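
For illustration, a minimal sketch of the IR this enables (adapted from the test cases added below; `%self` and `%idx` are hypothetical values): an index list mixing `!torch.none` with non-value tensors can now be converted to value semantics, with only the tensor elements wrapped in `torch.copy.to_vtensor` and the `none` elements passed through unchanged.

  %none = torch.constant.none
  %indices = torch.prim.ListConstruct %none, %idx : (!torch.none, !torch.tensor<[3],si64>) -> !torch.list<optional<tensor>>
  %ret = torch.aten.index.Tensor %self, %indices : !torch.tensor<[5],f32>, !torch.list<optional<tensor>> -> !torch.tensor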

@@ -36,6 +36,22 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
                                     overwrittenTensor);
 }
 
+static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
+  if (auto optionalType = type.dyn_cast<OptionalType>()) {
+    Type newContainedType = getContainerOrTensorTypeWithValueSemantics(
+        optionalType.getContainedType());
+    return OptionalType::get(newContainedType);
+  } else if (auto listType = type.dyn_cast<ListType>()) {
+    Type newContainedType =
+        getContainerOrTensorTypeWithValueSemantics(listType.getContainedType());
+    return ListType::get(newContainedType);
+  } else if (auto tensorType = type.dyn_cast<NonValueTensorType>()) {
+    return tensorType.getWithValueSemantics();
+  } else {
+    return nullptr;
+  }
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
@@ -76,23 +92,35 @@ public:
           continue;
 
         // TODO: Handle optional type in list type.
-        if (listType.getContainedType().isa<OptionalType>()) {
+        if (auto optionalType =
+                listType.getContainedType().dyn_cast<OptionalType>()) {
          if (!llvm::all_of(listConstruct.elements(), [](Value val) {
-                return val.getType().isa<NonValueTensorType>();
-              }))
+                return val.getType().isa<NonValueTensorType, Torch::NoneType>();
+              })) {
+            rewriter.cancelRootUpdate(op);
             return rewriter.notifyMatchFailure(
                 op, "unimplemented: list containing optional type is not "
                     "handled.");
+          }
         }
 
-        auto newListElements = llvm::to_vector<4>(llvm::map_range(
+        auto newListElements = llvm::to_vector(llvm::map_range(
             listConstruct.elements(), [&](Value tensor) -> Value {
-              return rewriter.create<CopyToValueTensorOp>(op->getLoc(), tensor);
+              if (tensor.getType().isa<NonValueTensorType>()) {
+                return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
+                                                            tensor);
+              }
+              return tensor;
             }));
+
+        Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
+        if (!newListType) {
+          rewriter.cancelRootUpdate(op);
+          return rewriter.notifyMatchFailure(
+              op, "Unable to convert list type to value semantics.");
+        }
+
         opOperand.set(rewriter.create<PrimListConstructOp>(
-            op->getLoc(),
-            Torch::ListType::get(newListElements.front().getType()),
-            newListElements));
+            op->getLoc(), newListType, newListElements));
       } else if (auto optionalType = operandType.dyn_cast<OptionalType>()) {
         // TODO: A more general way to handle the optional type is to
         // introduce a `copy.to_optional_vtensor` op.

@@ -269,6 +269,22 @@ public:
 };
 } // namespace
 
+namespace {
+class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isValidSubtype(op.x().getType(), op.result().getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "input tensor type is not a valid subtype of result type");
+    }
+    rewriter.replaceOp(op, op.x());
+    return success();
+  }
+};
+} // namespace
+
 static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
                                        PatternRewriter &rewriter,
                                        bool &madeChange) {
@@ -396,6 +412,7 @@ class SimplifyShapeCalculationsPass
     patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
     patterns.insert<DecomposeAtenSizeOp>(context);
     patterns.insert<RefineShapeCalculateOp>(context);
+    patterns.insert<FoldPrimUncheckedCastOp>(context);
     PrimIfOp::getCanonicalizationPatterns(patterns, context);
     Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);

@@ -120,8 +120,8 @@ func.func @torch.tensor.literal() -> !torch.tensor {
 // CHECK-SAME: (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
 // CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
 // CHECK: %[[INDICES_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDICES]] : !torch.vtensor<[2,3],si64>
-// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<vtensor<[2,3],si64>>
-// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor
+// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDICES_VTENSOR]] : (!torch.vtensor<[2,3],si64>) -> !torch.list<optional<vtensor<[2,3],si64>>>
+// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<optional<vtensor<[2,3],si64>>> -> !torch.vtensor
 // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK: return %[[RET]] : !torch.tensor
 func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
@@ -130,6 +130,27 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<
   return %ret : !torch.tensor
 }
 
+// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
+// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
+// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
+// CHECK-SAME: %[[INDICES_1:.*]]: !torch.tensor<[3],si64>) -> !torch.tensor {
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[NONE]], %[[INDICES_0]], %[[INDICES_1]] :
+// CHECK-SAME: (!torch.none, !torch.tensor<[2,3],si64>, !torch.tensor<[3],si64>) -> !torch.list<optional<tensor>>
+// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor<[5],f32>
+// CHECK: %[[INDICES_VTENSOR_0:.*]] = torch.copy.to_vtensor %[[INDICES_0]] : !torch.vtensor<[2,3],si64>
+// CHECK: %[[INDICES_VTENSOR_1:.*]] = torch.copy.to_vtensor %[[INDICES_1]] : !torch.vtensor<[3],si64>
+// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[NONE]], %[[INDICES_VTENSOR_0]], %[[INDICES_VTENSOR_1]] : (!torch.none, !torch.vtensor<[2,3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
+// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor
+// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
+// CHECK: return %[[RET]] : !torch.tensor
+func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%self: !torch.tensor<[5],f32>, %indices0: !torch.tensor<[2,3],si64>, %indices1: !torch.tensor<[3],si64>) -> !torch.tensor {
+  %none = torch.constant.none
+  %tensor_optional_list = torch.prim.ListConstruct %none, %indices0, %indices1 : (!torch.none, !torch.tensor<[2,3],si64>, !torch.tensor<[3],si64>) -> !torch.list<optional<tensor>>
+  %ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor>> -> !torch.tensor
+  return %ret : !torch.tensor
+}
+
 // CHECK-LABEL: func.func @torch.aten.uniform_(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
 // CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
@@ -184,9 +205,9 @@ func.func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.tensor) -> !torch.list<optional<tensor>>
 // CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor
 // CHECK: %[[INDEX_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDEX]] : !torch.vtensor
-// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<vtensor>
+// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<optional<vtensor>>
 // CHECK: %[[VALUES_VTENSOR:.*]] = torch.copy.to_vtensor %[[VALUES]] : !torch.vtensor
-// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<vtensor>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
+// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<optional<vtensor>>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
 // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
 // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor

@@ -406,3 +406,34 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor
   } : !torch.vtensor
   return %0 : !torch.vtensor
 }
+
+// CHECK-LABEL: func.func @fold_prim_unchecked_cast_op(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
+// CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk>
+// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk>
+// CHECK: } shapes {
+// CHECK: %[[VAL_6:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_2]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: torch.shape.calculate.yield.shapes %[[VAL_8]] : !torch.list<int>
+// CHECK: } : !torch.vtensor<[?,?],unk>
+// CHECK: %[[VAL_9:.*]] = torch.tensor_static_info_cast %[[VAL_10:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
+// CHECK: return %[[VAL_9]] : !torch.vtensor
+// CHECK: }
+func.func @fold_prim_unchecked_cast_op(%arg0: !torch.vtensor, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
+  %int0 = torch.constant.int 0
+  %tensor_list = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[?,?],si64>) -> !torch.list<optional<vtensor>>
+  %0 = torch.shape.calculate {
+    torch.shape.calculate.yield %arg0 : !torch.vtensor
+  } shapes {
+    %getitem = torch.aten.__getitem__.t %tensor_list, %int0 : !torch.list<optional<vtensor>>, !torch.int -> !torch.optional<vtensor>
+    %unchecked_cast = torch.prim.unchecked_cast %getitem : !torch.optional<vtensor> -> !torch.vtensor
+    %size = torch.aten.size %unchecked_cast : !torch.vtensor -> !torch.list<int>
+    torch.shape.calculate.yield.shapes %size : !torch.list<int>
+  } : !torch.vtensor
+  return %0 : !torch.vtensor
+}