canonicalizer: propagate type information across copy and cast ops (#1030)

Prior to this patch, the canonicalizers for `AtenSizeOp` and
`AtenSizeIntOp` succeeded only if the tensor operand's type information
included the size of the requested dimension(s).  We can extend the set
of optimizable cases by propagating types across operations whose result
type matches the input tensor type.

Specifically, this patch enables the canonicalizers for `AtenSizeOp` and
`AtenSizeIntOp` to see past `tensor_static_info_cast`,
`copy.to_vtensor`, and `copy.to_tensor` ops until they reach the first
value whose type contains size information for the requested
dimensions, with a maximum of 6 parent lookups to avoid indefinitely
long compilation times.  Any other op encountered during the traversal
causes the canonicalizer to give up.
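As an illustration (this mirrors the new test cases in the diff below;
the function name `@example` is just a placeholder), IR like the
following can now be folded even though the operand of
`torch.aten.size.int` is an unranked `!torch.vtensor`:

    func.func @example(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.int {
      %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
      %t = torch.copy.to_tensor %cast : !torch.tensor
      %vt = torch.copy.to_vtensor %t : !torch.vtensor
      %dim = torch.constant.int 0
      // The canonicalizer now traces %vt -> %t -> %cast -> %arg0, finds the
      // static shape [2,3], and effectively folds this op to
      // `torch.constant.int 2`.
      %size = torch.aten.size.int %vt, %dim : !torch.vtensor, !torch.int -> !torch.int
      return %size : !torch.int
    }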
Ashay Rane 2022-07-12 12:38:37 -07:00 committed by GitHub
parent e5e11e214b
commit ac4d7d10e0
2 changed files with 80 additions and 13 deletions


@@ -819,14 +819,55 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // AtenSizeOp
 //===----------------------------------------------------------------------===//
 
+// Traces at most 6 parents of `value` to determine the tensor type with known
+// dimension sizes, or returns failure if no such type is found.  If `dim` is
+// `None`, then all dimensions' sizes must be known.
+static FailureOr<BaseTensorType>
+traceKnownSizeTensorType(Value value, llvm::Optional<int64_t> dim) {
+  // Check whether a type contains the queried size information.
+  auto foundType = [](BaseTensorType tensorType, llvm::Optional<int64_t> dim) {
+    if (!tensorType.hasSizes())
+      return false;
+
+    if (dim == llvm::None)
+      return tensorType.areAllSizesKnown();
+
+    // If the dimension value is negative, convert it to a positive value.
+    ArrayRef<int64_t> sizes = tensorType.getSizes();
+    *dim = toPositiveDim(*dim, sizes.size());
+    return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize;
+  };
+
+  // Limit the loop count to 6 to avoid indefinite compilation times from
+  // unbounded IR traversals.
+  for (auto idx = 0; idx < 6; ++idx) {
+    if (!value || !value.getType().isa<BaseTensorType>())
+      return failure();
+
+    auto tensorType = value.getType().cast<BaseTensorType>();
+    if (foundType(tensorType, dim))
+      return tensorType;
+
+    auto op = value.getDefiningOp();
+    if (!op || !isa<CopyToValueTensorOp, CopyToNonValueTensorOp,
+                    TensorStaticInfoCastOp>(op))
+      return failure();
+
+    // In all ops of interest to us, the source tensor is operand #0.
+    value = op->getOperand(0);
+  }
+
+  return failure();
+}
+
 void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
   patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
-    auto type = op.getOperand().getType().dyn_cast<BaseTensorType>();
-    if (!type || !type.areAllSizesKnown())
+    auto type = traceKnownSizeTensorType(op.getOperand(), llvm::None);
+    if (failed(type))
       return rewriter.notifyMatchFailure(op, "all sizes not known");
     SmallVector<Value> listElements;
-    for (int64_t size : type.getSizes()) {
+    for (int64_t size : type->getSizes()) {
       listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
           op->getLoc(), rewriter.getI64IntegerAttr(size)));
     }
@@ -853,18 +894,15 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
-  auto type = getOperand(0).getType().dyn_cast<BaseTensorType>();
-  if (!type || !type.hasSizes())
+  int64_t dim;
+  if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
     return nullptr;
-  llvm::Optional<int64_t> dimOpt = matchLegalConstantIndexIntoListOfSize(
-      this->dim(), type.getSizes().size());
-  if (!dimOpt)
+  auto type = traceKnownSizeTensorType(this->self(), dim);
+  if (failed(type))
     return nullptr;
-  if (type.getSizes()[*dimOpt] == kUnknownSize)
-    return nullptr;
-  return IntegerAttr::get(IntegerType::get(getContext(), 64),
-                          type.getSizes()[*dimOpt]);
+  ArrayRef<int64_t> sizes = type->getSizes();
+  dim = toPositiveDim(dim, sizes.size());
+  return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
 }
 
 //===----------------------------------------------------------------------===//
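By way of contrast, a sketch of a case the new helper still rejects:
any defining op outside the cast/copy set stops the traversal, so the
fold does not apply. (`torch.aten.tanh` is an arbitrary illustrative
choice of a non-traced op, and `@not_folded` is a hypothetical name.)

    func.func @not_folded(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.int {
      %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
      // `tanh` is not a CopyToValueTensorOp, CopyToNonValueTensorOp, or
      // TensorStaticInfoCastOp, so traceKnownSizeTensorType returns failure
      // and torch.aten.size.int is left untouched.
      %t = torch.aten.tanh %cast : !torch.vtensor -> !torch.vtensor
      %dim = torch.constant.int 0
      %size = torch.aten.size.int %t, %dim : !torch.vtensor, !torch.int -> !torch.int
      return %size : !torch.int
    }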


@@ -1328,3 +1328,32 @@ func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
   return %2 : !torch.vtensor<[],si64>
 }
+
+// CHECK-LABEL:   @torch.aten.size$copy(
+// CHECK-SAME:      %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
+// CHECK:           %[[TWO:.*]] = torch.constant.int 2
+// CHECK:           %[[THREE:.*]] = torch.constant.int 3
+// CHECK:           %[[LIST:.*]] = torch.prim.ListConstruct %[[TWO]], %[[THREE]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           return %[[LIST]] : !torch.list<int>
+// CHECK:         }
+func.func @torch.aten.size$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
+  %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
+  %non_value_tensor = torch.copy.to_tensor %cast : !torch.tensor
+  %value_tensor = torch.copy.to_vtensor %non_value_tensor : !torch.vtensor
+  %size = torch.aten.size %value_tensor : !torch.vtensor -> !torch.list<int>
+  return %size : !torch.list<int>
+}
+
+// CHECK-LABEL:   @torch.aten.size.int$copy(
+// CHECK-SAME:      %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.int {
+// CHECK:           %[[TWO:.*]] = torch.constant.int 2
+// CHECK:           return %[[TWO]] : !torch.int
+// CHECK:         }
+func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.int {
+  %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
+  %non_value_tensor = torch.copy.to_tensor %cast : !torch.tensor
+  %value_tensor = torch.copy.to_vtensor %non_value_tensor : !torch.vtensor
+  %zero = torch.constant.int 0
+  %size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
+  return %size : !torch.int
+}
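For reference, these FileCheck tests are exercised by the canonicalize
pass through the lit RUN line at the top of the test file, which is not
shown in this hunk. The exact flags below are an assumption based on
the usual torch-mlir lit setup, but the invocation is typically of the
form:

    // RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
    // (flags assumed; check the actual RUN line in the test file)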