torch-mlir/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

1202 lines
45 KiB
C++
Raw Normal View History

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
#include "mlir/IR/Iterators.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
LogicalResult materializeFolds(ImplicitLocOpBuilder b,
ArrayRef<OpFoldResult> fold,
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVectorImpl<Value> &values) {
for (auto f : fold) {
if (auto val = dyn_cast<Value>(f)) {
values.push_back(val);
continue;
}
if (auto attr = dyn_cast<Attribute>(f)) {
if (auto val = dyn_cast<FloatAttr>(attr)) {
values.push_back(b.create<Torch::ConstantFloatOp>(
b.getType<Torch::FloatType>(), val));
continue;
}
if (auto val = dyn_cast<IntegerAttr>(attr)) {
values.push_back(
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
b.create<Torch::ConstantIntOp>(val.getValue().getSExtValue()));
continue;
}
}
return failure();
}
return success();
}
LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
auto list = value.getDefiningOp<Torch::PrimListConstructOp>();
if (!list)
return failure();
for (auto operand : list.getOperands())
vals.push_back(operand);
return success();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
constexpr int64_t kMaxFold = 16;
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>()) {
SmallVector<Value> unfolded;
LogicalResult gotList = getListOperands(tensor.getData(), unfolded);
vals = getAsOpFoldResult(unfolded);
return gotList;
}
if (auto full = value.getDefiningOp<Torch::AtenFullOp>()) {
auto ty = cast<ValueTensorType>(full.getType());
if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
return failure();
if (ty.getSizes()[0] > kMaxFold)
return failure();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
vals.resize(vals.size() + ty.getSizes()[0],
getAsOpFoldResult(full.getFillValue()));
return success();
}
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
Value usqSelf = unsqueeze.getSelf();
if (auto numToTensor =
usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
vals.push_back(getAsOpFoldResult(numToTensor.getA()));
return success();
}
}
// A common rank 0 tensor producer
if (auto numToTensor =
value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
vals.push_back(getAsOpFoldResult(numToTensor.getA()));
return success();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
// Last supported case: ValueTensorLiteralOp
auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
if (!literalOp)
return failure();
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// Check the type.
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
auto ty = cast<ValueTensorType>(literalOp.getType());
if (!ty.hasSizes() || ty.getSizes().size() > 1)
return failure();
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// make sure the type is not unsigned here before trying to materialize
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
if (!intTy || intTy.isUnsigned())
return failure();
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// if we have a rank 0 literal, we will be adding one element to the list
int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
if (listSize > kMaxFold)
return failure();
// check for a splat or dense attr
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
if (!splattr && !denseAttr)
return failure();
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// These are not mutually exclusive, so try splat first.
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
if (splattr) {
auto attr = splattr.getSplatValue<Attribute>();
vals.resize((int64_t)vals.size() + listSize, attr);
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
return success();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
}
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// remaining case: denseAttr
if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
return failure();
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
for (auto e : denseAttr.getValues<Attribute>())
vals.push_back(e);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
return success();
}
Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
SmallVector<Value> &listValues) {
auto dimList = b.create<Torch::PrimListConstructOp>(
b.getType<Torch::ListType>(listValues.front().getType()), listValues);
Value cstNone = b.create<Torch::ConstantNoneOp>();
Value cstFalse = b.create<Torch::ConstantBoolOp>(b.getBoolAttr(false));
return b.create<Torch::AtenTensorOp>(resultTy, dimList, cstNone, cstNone,
cstFalse);
}
} // namespace
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
/// ------ Propagation Patterns ------ ///
// The general goal of these patterns is to convert SomeTensorOp to [scalarOps
// -> PrimListOfInts -> AtenTensorOp] Since these tensorized shape calculation
// ops are chained together, sequences like OpA -> OpB will propagate OpA first:
// [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
// getListFromTensor(A), and further propagate scalarization.
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
namespace {
class PropagateAtenBroadcastToPattern
: public OpRewritePattern<AtenBroadcastToOp> {
public:
using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBroadcastToOp op,
PatternRewriter &rewriter) const override {
constexpr int64_t kMaxFold = 16;
// for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
// the element and convert to a full op.
auto ty = cast<ValueTensorType>(op.getType());
if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
return failure();
if (ty.getSizes()[0] > kMaxFold)
return failure();
SmallVector<OpFoldResult> fillFold;
if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
fillFold.size() != 1)
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<Value, 1> fillVals;
if (failed(materializeFolds(b, fillFold, fillVals)))
return failure();
Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
Value sizeList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
size);
Value none = b.create<Torch::ConstantNoneOp>();
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
none, none, none, cstFalse);
return success();
}
};
} // namespace
namespace {
class PropagateAtenShapeToTensorPattern
: public OpRewritePattern<Aten_ShapeAsTensorOp> {
public:
using OpRewritePattern<Aten_ShapeAsTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
ImplicitLocOpBuilder b(loc, rewriter);
auto self = op.getSelf();
auto selfTy = cast<BaseTensorType>(self.getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "self has unknown rank");
int64_t rank = selfTy.getSizes().size();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> dims;
for (int64_t i = 0; i < rank; ++i) {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
auto iv = b.create<Torch::ConstantIntOp>(i);
dims.push_back(b.createOrFold<Torch::AtenSizeIntOp>(
rewriter.getType<Torch::IntType>(), self, iv));
}
SmallVector<Value> materializedDims;
if (failed(materializeFolds(b, dims, materializedDims))) {
return failure();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result =
constructAtenTensorOpFromList(b, op.getType(), materializedDims);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class PropagateAtenCatPattern : public OpRewritePattern<AtenCatOp> {
public:
using OpRewritePattern<AtenCatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCatOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
constexpr int64_t kMaxFold = 16;
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!resultTy.hasSizes() || resultTy.getSizes().size() != 1 ||
!resultTy.areAllSizesKnown())
return failure();
if (resultTy.getSizes().front() > kMaxFold)
return failure();
if (!resultTy.hasDtype())
return failure();
SmallVector<Value> tensors;
if (failed(getListOperands(op.getTensors(), tensors)))
return failure();
SmallVector<OpFoldResult> scalars;
for (auto element : tensors) {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
llvm::SmallVector<OpFoldResult> delisted;
if (failed(getListFromTensor(element, delisted)))
return rewriter.notifyMatchFailure(op, "unknown op fold type");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
for (auto scalar : delisted)
scalars.push_back(scalar);
}
SmallVector<Value> values;
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
if (failed(materializeFolds(b, scalars, values)) || values.empty())
return rewriter.notifyMatchFailure(op, "unable to materialize constants");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result = constructAtenTensorOpFromList(b, resultTy, values);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class PropagateAtenIndexSelectPattern
: public OpRewritePattern<AtenIndexSelectOp> {
public:
using OpRewritePattern<AtenIndexSelectOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexSelectOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "requires a constant dim");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> idxFolds;
if (failed(getListFromTensor(op.getIndex(), idxFolds)))
return rewriter.notifyMatchFailure(op, "requires a constant index");
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "requires known rank");
auto selfShape = selfTy.getSizes();
int64_t selfRank = selfShape.size();
dim = dim < 0 ? dim + selfRank : dim;
int64_t dimLength = elements.size();
if (selfShape[dim] != dimLength)
return rewriter.notifyMatchFailure(
op, "dim length does not match number of elements");
for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim)
continue;
if (selfShape[i] != 1)
return rewriter.notifyMatchFailure(op,
"expects unary non-dim dimension");
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> selected;
for (auto idx : idxFolds) {
auto attr = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(idx));
if (!attr)
return failure();
int64_t indexInt = attr.getValue().getSExtValue();
indexInt = indexInt < 0 ? indexInt + dimLength : indexInt;
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
if (indexInt < 0 || indexInt >= dimLength)
return failure();
selected.push_back(elements[indexInt]);
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<Value> materializedSelected;
if (failed(materializeFolds(b, selected, materializedSelected)))
return failure();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result =
constructAtenTensorOpFromList(b, op.getType(), materializedSelected);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
// Conversion attempts to handle some common propagatable slice cases, namely
// splatted values, no-op slices, known list of values, or any case where a
// new construction can be generated from a previous set of scalars allowing
// the parent tensor to be bypassed.
class PropagateAtenSliceTensorPattern
: public OpRewritePattern<AtenSliceTensorOp> {
public:
using OpRewritePattern<AtenSliceTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSliceTensorOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
int64_t dim, start, end, step;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "requires a constant dim");
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(op, "requires a constant start");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(op, "requires a constant end");
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op, "requires a constant step");
if (step < 0)
return rewriter.notifyMatchFailure(op, "requires a positive step value");
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
auto selfShape = selfTy.getSizes();
int64_t selfRank = selfShape.size();
// Correct for negative indexing:
dim = dim < 0 ? dim + selfRank : dim;
int64_t dimLength = elements.size();
start = start < 0 ? start + dimLength : start;
end = end < 0 ? end + dimLength : end;
start = start < 0 ? 0 : start;
end = end < 0 ? 0 : end;
end = end > dimLength ? dimLength : end;
if (selfShape[dim] != dimLength)
return rewriter.notifyMatchFailure(
op, "dim length does not match number of elements");
for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim)
continue;
if (selfShape[i] != 1)
return rewriter.notifyMatchFailure(op,
"expects unary non-dim dimension");
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> selected;
for (int i = start; i < end; i += step)
selected.push_back(elements[i]);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<Value> values;
if (failed(materializeFolds(b, selected, values)))
return failure();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result = constructAtenTensorOpFromList(b, op.getType(), values);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
public:
using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenWhereSelfOp op,
PatternRewriter &rewriter) const override {
Value condition = op.getCondition();
Value self = op.getSelf();
Value other = op.getOther();
auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
if (!conditionTy || !conditionTy.hasSizes() ||
conditionTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad condition type");
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad self type");
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad other type");
int64_t conditionSize = selfTy.getSizes()[0];
int64_t selfSize = selfTy.getSizes()[0];
int64_t otherSize = otherTy.getSizes()[0];
if (selfSize != otherSize || selfSize != conditionSize)
return rewriter.notifyMatchFailure(
op,
"unimplemented: support for propogating with implicit broadcasting.");
constexpr int64_t kMaxFold = 16;
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
return rewriter.notifyMatchFailure(op,
"arguments are dynamic or too big");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> conditionFolds, selfFolds, otherFolds;
if (failed(getListFromTensor(condition, conditionFolds)) ||
failed(getListFromTensor(self, selfFolds)) ||
failed(getListFromTensor(other, otherFolds)))
return failure();
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<Value> conditionList, selfList, otherList;
if (failed(materializeFolds(b, conditionFolds, conditionList)) ||
failed(materializeFolds(b, selfFolds, selfList)) ||
failed(materializeFolds(b, otherFolds, otherList)))
return failure();
SmallVector<Value> whereVals;
auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), selfTy.getDtype());
auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), conditionTy.getDtype());
for (uint64_t i = 0; i < selfList.size(); i++) {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value rank0Cond = b.create<Torch::PrimNumToTensorScalarOp>(
rank0BoolTy, conditionList[i]);
Value rank0Self =
b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, selfList[i]);
Value rank0Other =
b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, otherList[i]);
Value rank0Where = b.create<AtenWhereSelfOp>(rank0IntTy, rank0Cond,
rank0Self, rank0Other);
whereVals.push_back(
b.create<AtenItemOp>(rewriter.getType<Torch::IntType>(), rank0Where));
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result = constructAtenTensorOpFromList(b, op.getType(), whereVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
public:
using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEqTensorOp op,
PatternRewriter &rewriter) const override {
Value self = op.getSelf();
Value other = op.getOther();
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad self type");
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad other type");
int64_t selfSize = selfTy.getSizes()[0];
int64_t otherSize = otherTy.getSizes()[0];
if (selfSize != otherSize)
return rewriter.notifyMatchFailure(
op,
"unimplemented: support for propogating with implicit broadcasting.");
constexpr int64_t kMaxFold = 16;
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
return rewriter.notifyMatchFailure(op,
"self or other is dynamic or too big");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> selfFolds, otherFolds;
if (failed(getListFromTensor(self, selfFolds)) ||
failed(getListFromTensor(other, otherFolds)))
return rewriter.notifyMatchFailure(op, "failed to get list from tensor");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<Value> selfList, otherList;
if (failed(materializeFolds(b, selfFolds, selfList)) ||
failed(materializeFolds(b, otherFolds, otherList)))
return rewriter.notifyMatchFailure(op, "failed to materialize folds");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> eqBoolFolds;
for (uint64_t i = 0; i < selfList.size(); i++) {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
OpFoldResult eqInt =
b.createOrFold<AtenEqIntOp>(selfList[i], otherList[i]);
if (auto eqIntVal = dyn_cast<Value>(eqInt))
eqInt = b.createOrFold<AtenIntBoolOp>(eqIntVal);
// if eqInt was an Attribute, it will materialize to a constant int op,
// which is what we want.
eqBoolFolds.push_back(eqInt);
}
SmallVector<Value> eqVals;
if (failed(materializeFolds(b, eqBoolFolds, eqVals))) {
return failure();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
Value result = constructAtenTensorOpFromList(b, op.getType(), eqVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
public:
using OpRewritePattern<AtenItemOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenItemOp op,
PatternRewriter &rewriter) const override {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<OpFoldResult> elements;
Value self = op.getSelf();
auto selfTy = cast<ValueTensorType>(self.getType());
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
// Rank 0 item op prop
if (selfTy.getSizes().size() == 0) {
auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
if (!squeezeDim && !numToTensor)
return rewriter.notifyMatchFailure(op,
"unhandled item of rank 0 operand");
if (numToTensor) {
rewriter.replaceOp(op, numToTensor.getA());
return success();
}
rewriter.replaceOpWithNewOp<AtenItemOp>(op, op.getType(),
squeezeDim.getSelf());
return success();
}
// Rank 1 item op prop
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
if (elements.size() != 1)
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
return rewriter.notifyMatchFailure(op, "expected one element");
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
SmallVector<Value, 1> materialized;
if (failed(materializeFolds(b, elements, materialized)))
return failure();
rewriter.replaceOp(op, materialized.front());
return success();
}
};
} // namespace
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
namespace {
template <typename OpTy> struct ArithmeticHelper {
static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
alpha = 1;
return success();
}
};
template <> struct ArithmeticHelper<AtenAddTensorOp> {
static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
return failure();
return success();
}
};
template <> struct ArithmeticHelper<AtenSubTensorOp> {
static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
return failure();
return success();
}
};
template <typename OpTy, typename ScalarOpTy>
class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Check type
auto resultTy = cast<ValueTensorType>(op.getType());
if (resultTy.getSizes().size() > 1)
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
return rewriter.notifyMatchFailure(op, "not an int type");
int64_t alpha;
if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
return rewriter.notifyMatchFailure(op, "alpha must be 1");
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> selfFold, otherFold;
if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
failed(getListFromTensor(op.getOther(), otherFold)) ||
selfFold.size() != otherFold.size())
return failure();
SmallVector<Value> selfVals, otherVals;
if (failed(materializeFolds(b, selfFold, selfVals)) ||
failed(materializeFolds(b, otherFold, otherVals)))
return failure();
SmallVector<OpFoldResult> resultFolds;
for (uint64_t i = 0; i < selfVals.size(); i++) {
resultFolds.push_back(b.createOrFold<ScalarOpTy>(
selfVals[i].getType(), selfVals[i], otherVals[i]));
}
SmallVector<Value> resultVals;
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
}
Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
/// ------ Fold Patterns ------ ///
// These are shape-specific folding patterns
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
namespace {
class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
public:
using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEqIntOp op,
PatternRewriter &rewriter) const override {
// replaces (size.int == 0) with false and adds an assert
// these comparisons are getting generated because onnx.Reshape considers 0
// to mean "don't change this dim". However, if the size we are passing to
// onnx.Reshape is a tensor dim, this is definitely never supposed to be
// interpreted as "don't change this dim".
int64_t otherInt;
if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
otherInt != 0)
return failure();
// in case the shape is a product of two ints, check each
if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
Value self = mulOp.getA();
Value other = mulOp.getB();
Value selfEq = rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
Value otherEq =
rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
return success();
}
// if lhs is size.int op, assert size > 0 and replace with false.
if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
Value selfGtOther = rewriter.create<AtenGtIntOp>(
op.getLoc(), op.getType(), op.getA(), op.getB());
rewriter.create<Torch::RuntimeAssertOp>(
op.getLoc(), selfGtOther,
rewriter.getStringAttr("Expected dim size > 0."));
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOp(op, cstFalse);
return success();
}
return failure();
}
};
} // namespace
namespace {
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
public:
using OpRewritePattern<AtenTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTensorOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<Value> elements;
if (failed(getListOperands(op.getData(), elements)))
return failure();
if (elements.size() < 1)
return rewriter.notifyMatchFailure(op, "no elements");
auto front = elements.front();
for (auto element : elements)
if (element != front)
return rewriter.notifyMatchFailure(op, "multiple elements found");
if (elements.size() != 1)
return rewriter.notifyMatchFailure(op, "expected no elements");
auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "dynamic output shape");
auto loc = op.getLoc();
SmallVector<Value> sizes;
for (auto size : resultTy.getSizes())
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), 1);
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
one);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
rewriter.replaceOpWithNewOp<AtenFullOp>(
op, resultTy, sizeList, elements.front(), none, none, none, cstFalse);
return success();
}
};
} // namespace
namespace {
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
template <typename SqueezeOp>
class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
public:
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
using OpRewritePattern<SqueezeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SqueezeOp op,
PatternRewriter &rewriter) const override {
auto resultTy = cast<ValueTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "Unknown result shape");
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
Value self = op.getSelf();
if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
// in the rank 0 case, just return the rank 0 scalar
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, atenFull.getFillValue());
return success();
}
SmallVector<Value> sizes;
for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(i)));
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(),
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
sizes);
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(op, resultTy, sizeList,
atenFull.getFillValue(),
none, none, none, none);
return success();
}
return failure();
}
};
} // namespace
namespace {
class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
public:
using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenWhereSelfOp op,
PatternRewriter &rewriter) const override {
auto getRoot = [](Value v) {
while (true) {
if (auto numToTensor =
v.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
v = numToTensor.getA();
continue;
}
break;
}
return v;
};
auto self = getRoot(op.getSelf());
auto other = getRoot(op.getOther());
if (self == other) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
auto selfSize = self.getDefiningOp<Torch::AtenSizeIntOp>();
auto otherSize = other.getDefiningOp<Torch::AtenSizeIntOp>();
if (selfSize && otherSize) {
if (selfSize.getSelf() != otherSize.getSelf())
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
return rewriter.notifyMatchFailure(op, "sizes not of same tensor");
int64_t dimSelf, dimOther;
if ((selfSize.getDim() != otherSize.getDim()) &&
(!matchPattern(selfSize.getDim(), m_TorchConstantInt(&dimSelf)) ||
!matchPattern(otherSize.getDim(), m_TorchConstantInt(&dimOther)) ||
(dimSelf != dimOther)))
return rewriter.notifyMatchFailure(op, "sizes not of same dim");
rewriter.replaceOp(op, op.getSelf());
return success();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
return rewriter.notifyMatchFailure(op, "unable to fold");
}
};
} // namespace
namespace {
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
public:
using OpRewritePattern<AtenUnsqueezeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenUnsqueezeOp op,
PatternRewriter &rewriter) const override {
auto resultTy = cast<ValueTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "Unknown result shape");
if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
SmallVector<Value> sizes;
for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(i)));
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(),
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
sizes);
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(op, resultTy, sizeList,
atenFull.getFillValue(),
none, none, none, none);
return success();
}
auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
if (squeezeOp && resultTy.getSizes().size() == 1) {
rewriter.replaceOp(op, squeezeOp.getSelf());
return success();
}
return failure();
}
};
} // namespace
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
/// ------ Canonicalization Patterns ------ ///
namespace {
// This is a specific pattern for converting views like [?,...,?,lastDim] ->
// [?,...,?,factor0,factor1] to unflatten, and views like
// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
// possible to infer that all but last shared dim match
// TODO: move this to an actual canonicalizer for view after deleting the
// conflicting decompositions for flatten/unflatten -> view.
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
public:
using OpRewritePattern<AtenViewOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenViewOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> viewSizes;
if (failed(getListOperands(op.getSize(), viewSizes)))
return rewriter.notifyMatchFailure(
op, "view size must be from a list construct");
auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
if (!selfTy || !selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "missing input type or sizes");
auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasSizes() ||
resultTy.getSizes().size() != viewSizes.size())
return rewriter.notifyMatchFailure(op, "missing result type or sizes");
int64_t inRank = selfTy.getSizes().size();
int64_t outRank = resultTy.getSizes().size();
SmallVector<int64_t> sizes(selfTy.getSizes());
int64_t endMatchingDim = -1;
// input sizes vs. provided view sizes comparison loop
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
int64_t providedSize;
bool providedStatic =
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
// if sizes[i] is static, it must match a constant in viewSizes[i]
if (sizes[i] != Torch::kUnknownSize) {
if (!providedStatic)
return rewriter.notifyMatchFailure(
op, "unsupported: found static input dim, but unable to match "
"provided view size on a constant. See position : " +
std::to_string(i));
if (providedSize != sizes[i]) {
endMatchingDim = i;
break;
}
continue;
}
// the remaining assumes sizes[i] is dynamic
// if provided dim is static, we can't verify it is a flatten/unflatten
// unless -1
if (i == outRank - 1 && providedStatic && providedSize == -1) {
endMatchingDim = i;
break;
}
if (providedStatic)
return rewriter.notifyMatchFailure(
op, "unexpected static view dim corresponding to dynamic input dim "
"at position : " +
std::to_string(i));
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, fail
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
return rewriter.notifyMatchFailure(
op, "expected dynamic view dim to come from a corresponding "
"size.int op. See position : " +
std::to_string(i));
int64_t dim;
// if the dim of the size int op doesn't match, fail
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != i)
return rewriter.notifyMatchFailure(
op,
"size int op dim cannot be matched to current dim at position : " +
std::to_string(i));
// passing the previous checks means viewSizes[i] = aten.size.int(self,
// i), so continue
}
// if all dims match and the ranks are equal, fold
if (endMatchingDim == -1 && inRank == outRank) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
if (endMatchingDim > -1 && inRank > outRank) {
// only support flattening last dim
if (endMatchingDim != outRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: output has more than back dim mismatching");
// flatten
Value start =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value end =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, resultTy, op.getSelf(), start, end);
return success();
}
if (endMatchingDim > -1 && inRank < outRank) {
// only support unflattening last dim
if (endMatchingDim != inRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input has more than back dim mismatching");
// unflatten
Value dim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value primList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(),
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, resultTy, op.getSelf(), dim, primList);
return success();
}
// examples that might reach this:
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
return rewriter.notifyMatchFailure(
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
", inRank=" + std::to_string(inRank) +
", outRank=" + std::to_string(outRank));
}
};
} // namespace
namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public:
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
for (auto use : op->getResults())
if (!use.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
}
};
} // namespace
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
namespace {
bool isSourceOpForShapeScalarization(Operation *op) {
return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
}
bool isPrimListOfInts(Operation *op) {
auto primListOp = dyn_cast<Torch::PrimListConstructOp>(op);
if (!primListOp)
return false;
auto listType = dyn_cast<Torch::ListType>(primListOp.getType());
if (!listType)
return false;
return llvm::isa<Torch::IntType>(listType.getContainedType());
}
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
bool isAnchorOp(Operation *op) {
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
isPrimListOfInts(op);
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
patterns.getContext());
}
void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
patterns.add<CanonicalizeAtenViewPattern>(patterns.getContext());
}
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
// A note on division: onnx.Div from int, int -> int types rounds towards
// zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
// but this was artificially plummbed through. Unfortunately, there is no
// scalar trunc div op in torch; however, we can safely assume all operands
// are positive so floor divide should be a sufficient scalar replacement.
patterns.insert<
PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
patterns.getContext());
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
}
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}
} // namespace
namespace {
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
// populate patterns
populateScalarizationPropagationPatterns(patterns);
populateScalarizationFoldPatterns(patterns);
populateScalarizationCanonicalizePatterns(patterns);
populateScalarizationRemovePatterns(patterns);
context->getLoadedDialect<mlir::arith::ArithDialect>()
->getCanonicalizationPatterns(patterns);
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
// don't load torch canonicalization patterns, since these may lead to
// issues with propagation
// walk func op bottom-up to collect a SetVector of shape-related operations
// When we pass this SetVector to the pattern rewrite driver, it will
// process the operations top-down, thereby propagating scalarization
// starting from sources.
auto funcOp = getOperation();
llvm::SetVector<Operation *> shapeCalculationOps;
funcOp.walk<WalkOrder::PostOrder, mlir::ReverseIterator>(
[&](Operation *op) {
// Walking bottom-up, start adding ops when we reach an anchor point
// (a prim list of ints)
Add More Scalarize Shapes Patterns (#3810) ### new patterns: 1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op 2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX generated IR long enough. 3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to lack of folding for subsequent `where.self` operations. 4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern` ### other changes: 1. Add two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and it is always shape related. The Arange op only takes int inputs, and these are all shape related. Adds a size check to getting a list from literal ops. 2. Improved folders for int arithmetic ops to fold some common patterns. 3. adds the ability to get some values from scalar-tensor ops to getListFromTensor. 4. further cleans up getListFromTensor for readability. ### points to scrutinize: 1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart. 2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get cse'd with existing elements of the worklist before we even get to this pass.
2024-10-22 08:42:39 +08:00
if (isAnchorOp(op)) {
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
shapeCalculationOps.insert(op);
return;
}
// add view ops for now until the decompositions for flatten and
// unflatten are removed.
if (isa<AtenViewOp>(op)) {
shapeCalculationOps.insert(op);
return;
}
// Insert the op if any of it's consumers have already been identified
// as a shape calculation op. To avoid adding the producer of
// something like a size.int op, don't add ops when their consumer is
// a source op for shape scalarization. Here is some sample IR:
// ------
// %0 = aten.matmul %arg0, %arg1 : ... -> !torch.vtensor<[?,?,?],f32>
// %1 = aten.size.int %0, %int0 : !torch.int
// %2 = prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
// return %2 : !torch.list<int>
// ------
// In this example, don't add the matmul (%0), or it's producers, to
// shapeCalculationOps. It's consumer (%1) is indeed a shape
// calculation op, but the size.int op is an elementary unit of shape
// computation. No futher gathering of producers is necessary to
// reduce this. Similarly, don't add the `self` of a view op.
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (shapeCalculationOps.contains(userOp) &&
!isSourceOpForShapeScalarization(userOp) &&
!isa<AtenViewOp>(userOp)) {
shapeCalculationOps.insert(op);
return;
}
}
});
GreedyRewriteConfig config;
// When propagating, we need to go back and clean up aten.Tensor ops that
// have been futher propagated. It is also necessary to add newly created
// ops for custom folding after scalarizing a where.self op.
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
if (failed(applyOpPatternsAndFold(shapeCalculationOps.getArrayRef(),
std::move(patterns), config))) {
return signalPassFailure();
}
Rework Scalarize Shapes Pass (#3799) This is a first step towards reworking the scalarize-shapes pass which has been integral to our ONNX frontend path detangling shape computations. ## Purpose: 1. Restrict the scope of the pass to only apply to op sequences which are used to compute shapes. 2. Make the pass more efficient by applying patterns in an appropriate order for scalarization propagation. 3. Report failed scalarization patterns for easier debugging (Not yet implemented). I can't seem to find a good path for this right now to capture the right diagnostics. I'd like to defer this addition to a later patch so we can add some high-value patterns to this pass in the meantime. With these changes, some reworking of the conversions themselves will be necessary. 1. The removal of the SqueezeDim fold pattern was an appropriate fix to avoid folding a pattern that may be needed to propagate further. The reversal of pattern application order uncovered this bug. The addition of rank 0 item logic was added to replace the functionality needed from the squeeze dim pattern. 2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to allow processing value tensor literals without immediately materializing the ints. This should factor out a significant portion of code that was used in specific cases to handle constants. ## RFC 1: Currently, we are going to add all prim list of int ops to the worklist. Can anyone identify problems with uniformly anchoring on prim lists of ints? E.g. Does there exist a Torch Op satisfying all of the following conditions: 1. Accepts a list of constant ints, LIST, as an input 2. The role of LIST is **not** shape related. All the examples I can think of are indeed shape related: padding ints passed to a pad op, kernel size ints passed to a conv op, size ints passed to a view op, etc. 4. The LIST is not gotten entirely from scalars already. If there does not exist a torch op satisfying all three of those conditions, I think it will be safe to "anchor" on prim lists of ints. ### Conclusion for RFC 1: I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for all references of `AnyTorchListOfTorchIntType` and verified this will not be problematic to apply in any of those cases. ## RFC 2: What should I use to report failed scalarization? Like my dumb idea was just to walk back through the func op after applying the passes and check if anything in the worklist is still a tensor. If so, emit/log a warning. It certainly works, since you can just look at the warnings and start debugging from the last printed warning upwards, but there has to be a better way to handle this without walking back through the func.func op. ### Conclusion for RFC 2: I tried a few things without much success. The fundamental problem is that identifying the cause of a failed scalarization could be myriad: 1. We could be missing a pattern for an op entirely: E.g., a pattern we need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp -> AtenMulIntOp). 2. We could fail a scalarization pattern because it should fold instead. This is specifically the case for rank0 where.self ops. These ops MUST fold, or we need to have custom lowering logic for the rank 0 case. 3. Walking through the func op a second time and emiting a warning for ops that have tensor result types seems to give locations that are inconsistent or hard to track in the converted IR. Doing this on IR that doesn't apply any patterns seems to give decent information, but it's still dramatically insufficient considering how complex these patterns can get, and still takes manually reading IR to try and figure out what is really blocking the simplification. I'd like to skip out on fleshing out the error reporting for now and come back to it after iterating a few time on the patterns.
2024-10-22 01:47:19 +08:00
// TODO: Warn when failing to process operations in the worklist.
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createScalarizeShapesPass() {
return std::make_unique<ScalarizeShapesPass>();
}