mirror of https://github.com/llvm/torch-mlir
Add Some Folders For Small Reshape Ops (#3813)
### Changes 1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`, and `aten.unflatten.int` 2. Folder for transpose 3. Extended support for the `aten.slice.Tensor` op folder to include negative strides. ### Motivation The biggest motivation for this patch is to fold the extremely convoluted ir that gets generated when exporting a pytorch model with an `aten.pad` op to ONNX, then re-importing and lowering back to torch. For example, the verbose output of the e2e test `PadModule_basic` with `-c onnx`: ```mlir module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %none = torch.constant.none %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> %5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> %13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32> %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> return %16 : !torch.vtensor<[?,?,?,?],f32> } } {-# dialect_resources: { builtin: { _: "0x080000000400000000000000", __1: "0x080000000000000000000000010000000000000002000000000000000300000000000000", __2: "0x080000000000000000000000", __3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000", __4: "0x080000000000000000000000", __5: "0x08000000FFFFFFFFFFFFFFFF", __6: "0x080000000100000000000080", __7: "0x08000000FFFFFFFFFFFFFFFF", __8: "0x08000000FFFFFFFFFFFFFFFF", __9: "0x080000000000C03F" } } #-} ``` Get's converted to the torch IR: ```mlir module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %float1.500000e00 = torch.constant.float 1.500000e+00 %int-9223372036854775807 = torch.constant.int -9223372036854775807 %int-1 = torch.constant.int -1 %int7 = torch.constant.int 7 %int6 = torch.constant.int 6 %int5 = torch.constant.int 5 %int3 = torch.constant.int 3 %int8 = torch.constant.int 8 %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int4 = torch.constant.int 4 %int0 = torch.constant.int 0 %0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64> %1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int> %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64> %3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> %4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> %5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int> %6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64> %7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int %9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int %11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int %13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int %15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int %17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int %19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int %21 = torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int %23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> %24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32> return %24 : !torch.vtensor<[?,?,?,?],f32> } } ``` ***All of these operations are useless***. It is literally the result of needing to reverse (and change the lexicographic order hierarchy of) padding ints provided via torch vs. ONNX pad ops, which is then subsequently UNDONE by our ONNX->Torch lowering (represented in the ordering of the generated list construct). With the added folders in this patch, the torch IR becomes: ``` module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} { %float1.500000e00 = torch.constant.float 1.500000e+00 %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %int1 = torch.constant.int 1 %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> %1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32> return %1 : !torch.vtensor<[?,?,?,?],f32> } } ```pull/3733/head
parent
d6feb2179c
commit
1259e8a00a
|
@ -8080,6 +8080,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
|
||||
|
@ -9672,6 +9673,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
|
||||
|
@ -9696,6 +9698,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,24 @@ using namespace mlir::torch::Torch;
|
|||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult genericViewLikeFold(Attribute self, Type resultType) {
|
||||
auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(self);
|
||||
if (!selfAttr)
|
||||
return nullptr;
|
||||
|
||||
auto resultTy = dyn_cast_or_null<ValueTensorType>(resultType);
|
||||
if (!resultTy || !resultTy.areAllSizesKnown())
|
||||
return nullptr;
|
||||
|
||||
if (selfAttr.isSplat()) {
|
||||
return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
|
||||
selfAttr.getSplatValue<Attribute>());
|
||||
}
|
||||
return DenseElementsAttr::get(
|
||||
resultTy.toBuiltinTensor(),
|
||||
llvm::to_vector(selfAttr.getValues<Attribute>()));
|
||||
}
|
||||
|
||||
Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
|
||||
Location loc, Value value,
|
||||
Type desiredType,
|
||||
|
@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType()))
|
||||
return genericFold;
|
||||
auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
|
||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
|
@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenFlattenUsingIntsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) {
|
||||
return genericViewLikeFold(adaptor.getSelf(), getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenUnflattenIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) {
|
||||
return genericViewLikeFold(adaptor.getSelf(), getType());
|
||||
}
|
||||
|
||||
void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
// if there are only two sizes and one of them is statically 1, then convert
|
||||
|
@ -3722,6 +3754,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
|||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenTransposeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) {
|
||||
// first check for no-op
|
||||
IntegerAttr dim0 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim0());
|
||||
IntegerAttr dim1 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim1());
|
||||
if (!dim0 || !dim1)
|
||||
return nullptr;
|
||||
int64_t _dim0 = dim0.getValue().getSExtValue();
|
||||
int64_t _dim1 = dim1.getValue().getSExtValue();
|
||||
auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
|
||||
if (!selfTy || !selfTy.hasSizes())
|
||||
return nullptr;
|
||||
int64_t rank = selfTy.getSizes().size();
|
||||
_dim0 = toPositiveDim(_dim0, rank);
|
||||
_dim1 = toPositiveDim(_dim1, rank);
|
||||
if (!isValidDim(_dim0, rank) || !isValidDim(_dim1, rank))
|
||||
return nullptr;
|
||||
// if dims are the same, return self
|
||||
if (_dim0 == _dim1)
|
||||
return getSelf();
|
||||
|
||||
// We set a maximum folding size of 16. This is a reasonable upper limit
|
||||
// for shape computations.
|
||||
constexpr int64_t kMaxFoldSize = 16;
|
||||
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||
if (!self || self.getNumElements() > kMaxFoldSize)
|
||||
return nullptr;
|
||||
auto resultTy = dyn_cast<ValueTensorType>(getType());
|
||||
if (!selfTy || !resultTy || !selfTy.areAllSizesKnown())
|
||||
return nullptr;
|
||||
if (self.isSplat())
|
||||
return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
|
||||
self.getSplatValue<Attribute>());
|
||||
|
||||
// TODO: add support for rank != 2
|
||||
if (rank != 2)
|
||||
return nullptr;
|
||||
|
||||
ArrayRef<int64_t> sizes = selfTy.getSizes();
|
||||
auto values = llvm::to_vector(self.getValues<Attribute>());
|
||||
// reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0],
|
||||
// i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])].
|
||||
// e.g., Self size = [4,2]; Trans size = [2,4].
|
||||
// reindex(i) = (i % 4)*2 + (i // 4) .
|
||||
// i = 0 -> Trans[0,0] -> Self[0,0] -> 0 .
|
||||
// i = 1 -> Trans[0,1] -> Self[1,0] -> 2 .
|
||||
// i = 2 -> Trans[0,2] -> Self[2,0] -> 4 .
|
||||
// i = 3 -> Trans[0,3] -> Self[3,0] -> 6 .
|
||||
// i = 4 -> Trans[1,0] -> Self[0,1] -> 1 .
|
||||
// i = 5 -> Trans[1,1] -> Self[1,1] -> 3 .
|
||||
auto reindex = [&](int64_t i) {
|
||||
return (i % sizes[0]) * sizes[1] + (i / sizes[0]);
|
||||
};
|
||||
SmallVector<Attribute> reordered;
|
||||
for (int64_t i = 0; i < self.getNumElements(); i++) {
|
||||
reordered.push_back(values[reindex(i)]);
|
||||
}
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3898,15 +3993,18 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
// Fold the slice if the output tensor is relatively small, currently
|
||||
// coded to 16:
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (input && start && step && dim && count <= kMaxFold) {
|
||||
if (input && start && step && dim && end && count <= kMaxFold) {
|
||||
int64_t begin = start.getValue().getSExtValue();
|
||||
int64_t limit = end.getValue().getSExtValue();
|
||||
int64_t stride = step.getValue().getSExtValue();
|
||||
if (stride < 1)
|
||||
return nullptr;
|
||||
begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
|
||||
limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
|
||||
limit = limit < 0 ? -1 : limit;
|
||||
limit = std::min(limit, inType.getSizes()[dimInt]);
|
||||
bool validIterArgs =
|
||||
(stride > 0 && begin < limit) || (stride < 0 && begin > limit);
|
||||
assert(validIterArgs &&
|
||||
"aten.slice.Tensor iteration args are statically invalid.");
|
||||
|
||||
int64_t inputRank = inType.getSizes().size();
|
||||
llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
|
||||
|
@ -3919,10 +4017,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
|
||||
if (currDim >= inputRank)
|
||||
return;
|
||||
size_t _begin = (currDim == dimInt) ? begin : 0;
|
||||
size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
|
||||
size_t _stride = (currDim == dimInt) ? stride : 1;
|
||||
for (size_t i = _begin; i < _limit; i += _stride) {
|
||||
int64_t _stride = (currDim == dimInt) ? stride : 1;
|
||||
int64_t _begin = (currDim == dimInt) ? begin : 0;
|
||||
int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
|
||||
// ensure that the limit is reached exactly (even with negative strides)
|
||||
// E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11
|
||||
// = 10 + (10-0) % 3 .
|
||||
// E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 +
|
||||
// (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 .
|
||||
// Note: cpp uses true math remainder "n % d = least positive int, x, such
|
||||
// that d divides (n - x)"
|
||||
int64_t limit_rem = (_limit - _begin) % _stride;
|
||||
limit_rem =
|
||||
(_stride > 0 || limit_rem == 0) ? limit_rem : limit_rem - _stride;
|
||||
_limit += limit_rem;
|
||||
for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) {
|
||||
if (currDim == inputRank - 1) {
|
||||
values.push_back(input.getValues<Attribute>()[currOffset + i]);
|
||||
}
|
||||
|
|
|
@ -2677,20 +2677,6 @@ ONNX_XFAIL_SET = {
|
|||
"MultinomialModule2D_basic",
|
||||
"MultinomialModule2D_F32",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"ReflectionPad1dModule2dInput_Right",
|
||||
"ReflectionPad1dModule2dInput_basic",
|
||||
"ReflectionPad1dModule3dInput_Left",
|
||||
"ReflectionPad1dModule3dInput_basic",
|
||||
"ReflectionPad2dModule_Bottom",
|
||||
"ReflectionPad2dModule_Left",
|
||||
"ReflectionPad2dModule_Right",
|
||||
"ReflectionPad2dModule_Top",
|
||||
"ReflectionPad2dModule_basic",
|
||||
"ReplicationPad2dModule_basic",
|
||||
"ReplicationPad2dModule_bottom0",
|
||||
"ReplicationPad2dModule_left0",
|
||||
"ReplicationPad2dModule_right0",
|
||||
"ReplicationPad2dModule_top0",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
|
|
@ -684,7 +684,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)")
|
||||
emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)")
|
||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
|
||||
|
@ -769,9 +769,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
|
||||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)", has_folder=True)
|
||||
emit(
|
||||
"aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
|
||||
"aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
has_folder=True,
|
||||
)
|
||||
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||
|
|
|
@ -1682,6 +1682,82 @@ func.func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?
|
|||
return %1 : !torch.tensor<[?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$fold_splat(
|
||||
// CHECK: %[[SPLAT:.*]] = torch.vtensor.literal(dense<2> : tensor<2x4x1xsi64>) : !torch.vtensor<[2,4,1],si64>
|
||||
// CHECK: return %[[SPLAT]] : !torch.vtensor<[2,4,1],si64>
|
||||
func.func @torch.aten.view$fold_splat() -> !torch.vtensor<[2,4,1],si64> {
|
||||
%int4 = torch.constant.int 4
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<2> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
|
||||
%1 = torch.prim.ListConstruct %int2, %int4, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,4,1],si64>
|
||||
return %2 : !torch.vtensor<[2,4,1],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$fold_literal(
|
||||
// CHECK: %[[LITERAL:.*]] = torch.vtensor.literal(dense<[
|
||||
// CHECK-SAME: [
|
||||
// CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7]]]> : tensor<1x4x2xsi64>) : !torch.vtensor<[1,4,2],si64>
|
||||
// CHECK: return %[[LITERAL]] : !torch.vtensor<[1,4,2],si64>
|
||||
func.func @torch.aten.view$fold_literal() -> !torch.vtensor<[1,4,2],si64> {
|
||||
%int4 = torch.constant.int 4
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<[0,1,2,3,4,5,6,7]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
|
||||
%1 = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,4,2],si64>
|
||||
return %2 : !torch.vtensor<[1,4,2],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_literal(
|
||||
// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
|
||||
// CHECK-SAME: [0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xsi64>) : !torch.vtensor<[2,4],si64>
|
||||
// CHECK: return %[[LIT]] : !torch.vtensor<[2,4],si64>
|
||||
func.func @torch.aten.transpose.int$fold_literal() -> !torch.vtensor<[2,4],si64> {
|
||||
%int-1 = torch.constant.int -1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.vtensor.literal(dense<[[0,1],[2,3],[4,5],[6,7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
|
||||
%1 = torch.aten.transpose.int %0, %int-1, %int0 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4], si64>
|
||||
return %1 : !torch.vtensor<[2,4],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_noop(
|
||||
// CHECK: return %arg0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.transpose.int$fold_noop(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int-1 = torch.constant.int -1
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.aten.transpose.int %arg0, %int-1, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.slice.Tensor$flip_slice_fold(
|
||||
// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
|
||||
// CHECK-SAME: [6, 7], [4, 5], [2, 3], [0, 1]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
|
||||
// CHECK: return %[[LIT]] : !torch.vtensor<[4,2],si64>
|
||||
func.func @torch.aten.slice.Tensor$flip_slice_fold() -> !torch.vtensor<[4,2],si64> {
|
||||
%int-9223372036854775807 = torch.constant.int -9223372036854775807
|
||||
%int-1 = torch.constant.int -1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
|
||||
%1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
|
||||
return %1 : !torch.vtensor<[4,2],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.slice.Tensor$negative_two_stride_fold(
|
||||
// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
|
||||
// CHECK-SAME: [6, 7], [2, 3]]> : tensor<2x2xsi64>) : !torch.vtensor<[2,2],si64>
|
||||
// CHECK: return %[[LIT]] : !torch.vtensor<[2,2],si64>
|
||||
func.func @torch.aten.slice.Tensor$negative_two_stride_fold() -> !torch.vtensor<[2,2],si64> {
|
||||
%int-5 = torch.constant.int -5
|
||||
%int-1 = torch.constant.int -1
|
||||
%int-2 = torch.constant.int -2
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
|
||||
%1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-5, %int-2 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64>
|
||||
return %1 : !torch.vtensor<[2,2],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.float$fold_zero_dividend(
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: return %[[CST0]] : !torch.float
|
||||
|
|
Loading…
Reference in New Issue