mirror of https://github.com/llvm/torch-mlir
canonicalizer: propagate type information across copy and cast ops (#1030)
Prior to this patch, the canonicalizers for `AtenSizeOp` and `AtenSizeIntOp` succeeded only if the tensor operand's type information included the size of the requested dimension(s). We can extend the set of optimizable cases by propagating types across operations whose result type matches the input tensor type. Specifically, this patch enables the canonicalizers for `AtenSizeOp` and `AtenSizeIntOp` to see past `tensor_static_info_cast`, `copy.to_vtensor`, and `copy.to_tensor` ops until it reaches the first op whose result type contains size information for the requested dimensions, with a maximum bound of 6 parent lookups to avoid indefinite compilation times. All other encountered ops cause the canonicalizer to give up.pull/1047/head
parent
e5e11e214b
commit
ac4d7d10e0
|
@ -819,14 +819,55 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Traces at most 6 parents of `value` to determine the tensor type with known
|
||||
// dimension size or returns failure if such a type was not found. If `dim` is
|
||||
// `None`, then all dimension's sizes must be known.
|
||||
static FailureOr<BaseTensorType>
|
||||
traceKnownSizeTensorType(Value value, llvm::Optional<int64_t> dim) {
|
||||
// Function to check if we found a type that contains the queried information.
|
||||
auto foundType = [](BaseTensorType tensorType, llvm::Optional<int64_t>(dim)) {
|
||||
if (!tensorType.hasSizes())
|
||||
return false;
|
||||
|
||||
if (dim == llvm::None)
|
||||
return tensorType.areAllSizesKnown();
|
||||
|
||||
// If the dimension value is negative, then convert it to a positive value.
|
||||
ArrayRef<int64_t> sizes = tensorType.getSizes();
|
||||
*dim = toPositiveDim(*dim, sizes.size());
|
||||
return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize;
|
||||
};
|
||||
|
||||
// Limit the loop count to 6 to avoid indefinite compilation times from
|
||||
// unbounded IR traversals.
|
||||
for (auto idx = 0; idx < 6; ++idx) {
|
||||
if (!value || !value.getType().isa<BaseTensorType>())
|
||||
return failure();
|
||||
|
||||
auto tensorType = value.getType().cast<BaseTensorType>();
|
||||
if (foundType(tensorType, dim))
|
||||
return tensorType;
|
||||
|
||||
auto op = value.getDefiningOp();
|
||||
if (!op || !isa<CopyToValueTensorOp, CopyToNonValueTensorOp,
|
||||
TensorStaticInfoCastOp>(op))
|
||||
return failure();
|
||||
|
||||
// In all ops of interest to us, the source tensor is operand #0.
|
||||
value = op->getOperand(0);
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
||||
auto type = op.getOperand().getType().dyn_cast<BaseTensorType>();
|
||||
if (!type || !type.areAllSizesKnown())
|
||||
auto type = traceKnownSizeTensorType(op.getOperand(), llvm::None);
|
||||
if (failed(type))
|
||||
return rewriter.notifyMatchFailure(op, "all sizes not known");
|
||||
SmallVector<Value> listElements;
|
||||
for (int64_t size : type.getSizes()) {
|
||||
for (int64_t size : type->getSizes()) {
|
||||
listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||
}
|
||||
|
@ -853,18 +894,15 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto type = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
if (!type || !type.hasSizes())
|
||||
int64_t dim;
|
||||
if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
|
||||
return nullptr;
|
||||
|
||||
llvm::Optional<int64_t> dimOpt = matchLegalConstantIndexIntoListOfSize(
|
||||
this->dim(), type.getSizes().size());
|
||||
if (!dimOpt)
|
||||
auto type = traceKnownSizeTensorType(this->self(), dim);
|
||||
if (failed(type))
|
||||
return nullptr;
|
||||
if (type.getSizes()[*dimOpt] == kUnknownSize)
|
||||
return nullptr;
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
type.getSizes()[*dimOpt]);
|
||||
ArrayRef<int64_t> sizes = type->getSizes();
|
||||
dim = toPositiveDim(dim, sizes.size());
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1328,3 +1328,32 @@ func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
|
|||
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @torch.aten.size$copy(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
|
||||
// CHECK: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[THREE:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[TWO]], %[[THREE]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: return %[[LIST]] : !torch.list<int>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.size$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<int> {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
|
||||
%non_value_tensor = torch.copy.to_tensor %cast : !torch.tensor
|
||||
%value_tensor = torch.copy.to_vtensor %non_value_tensor : !torch.vtensor
|
||||
%size = torch.aten.size %value_tensor : !torch.vtensor -> !torch.list<int>
|
||||
return %size : !torch.list<int>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @torch.aten.size.int$copy(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.int {
|
||||
// CHECK: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK: return %[[TWO]] : !torch.int
|
||||
// CHECK: }
|
||||
func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.int {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor
|
||||
%non_value_tensor = torch.copy.to_tensor %cast : !torch.tensor
|
||||
%value_tensor = torch.copy.to_vtensor %non_value_tensor : !torch.vtensor
|
||||
%zero = torch.constant.int 0
|
||||
%size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
|
||||
return %size : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue