mirror of https://github.com/llvm/torch-mlir
Add Op for `torch.aten.unfold` (#3772)
# Description Implementation of the op for `torch.aten.unfold`: [TorchToLinalg Op Support #347](https://github.com/nod-ai/SHARK-ModelDev/issues/849) Documentation of op can be found here: [PyTorch Docs](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html) For this op, we apply a sliding window of some `size` along a single `dimension`, with `step` in between iterations. `Declaration: aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)` The resulting `unfolded` tensor modifies the shape of `dimension` to be equal to the number of blocks that the sliding windows extracts/inserts, with an additional dimension of `size` appended (the number of cols of the output tensor directly translates from the size of the sliding window). So if we had a tensor of rank 3 (A x B x C), with dimension = 1, size = 2 and step = 2: (A x B x C) |=> (A x (B - size) // step + 1 x C x size) After extracting the window from the input tensor, we insert the (1 x size) slice into the output tensor. We can make this simpler by mapping the output indices from the input indices, like they do in the official implementation: [PyTorch Code](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1694)pull/3776/head
parent
7830c00ca2
commit
d49eabb3fc
|
@ -13692,6 +13692,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dimension,
|
||||
Torch_IntType:$size,
|
||||
Torch_IntType:$step
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenUnfoldOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -2611,6 +2611,167 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto self = adaptor.getSelf();
|
||||
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||
|
||||
int64_t dimension;
|
||||
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dimension");
|
||||
}
|
||||
int64_t size;
|
||||
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
|
||||
return rewriter.notifyMatchFailure(op, "only support constant int size");
|
||||
}
|
||||
int64_t step;
|
||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||
return rewriter.notifyMatchFailure(op, "only support constant int step");
|
||||
}
|
||||
|
||||
if (step <= 0) {
|
||||
return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
|
||||
}
|
||||
|
||||
int64_t selfRank = selfType.getRank();
|
||||
|
||||
// Zero-Rank case
|
||||
if (selfRank == 0) {
|
||||
// Empty tensor
|
||||
if (size == 0) {
|
||||
RankedTensorType resultType =
|
||||
RankedTensorType::get({0}, selfType.getElementType());
|
||||
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, resultType.getShape(), resultType.getElementType());
|
||||
|
||||
rewriter.replaceOp(op, emptyTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, RankedTensorType::get({1}, selfType.getElementType()), self,
|
||||
ArrayRef<ReassociationIndices>{});
|
||||
rewriter.replaceOp(op, unsqueezedSelf);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto shape = selfType.getShape();
|
||||
|
||||
if (dimension < 0) {
|
||||
dimension = toPositiveDim(dimension, selfRank);
|
||||
}
|
||||
if (!isValidDim(dimension, selfRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dimension out of range");
|
||||
}
|
||||
|
||||
Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);
|
||||
|
||||
Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
|
||||
Value sizeCheck = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, sizeCheck,
|
||||
rewriter.getStringAttr("size must be <= target dimension"));
|
||||
|
||||
/* Calculate output shape of unfold op:
|
||||
* https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
|
||||
* outputShape[dimension] is set to numBlocks, with size appended as an
|
||||
* additional dimension
|
||||
*/
|
||||
SmallVector<OpFoldResult> outputShape;
|
||||
for (int64_t i = 0; i < selfRank; i++) {
|
||||
if (i == dimension) {
|
||||
outputShape.push_back(getDynamicOrStaticNumBlocks(
|
||||
rewriter, loc, shape[dimension], dimSize, size, step));
|
||||
} else if (shape[i] == ShapedType::kDynamic) {
|
||||
outputShape.push_back(
|
||||
OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
|
||||
} else {
|
||||
outputShape.push_back(rewriter.getIndexAttr(shape[i]));
|
||||
}
|
||||
}
|
||||
outputShape.push_back(rewriter.getIndexAttr(size));
|
||||
|
||||
// Empty tensor to insert values into
|
||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, outputShape, selfType.getElementType());
|
||||
|
||||
/**
|
||||
* Use reindexing to map output indices to input indices
|
||||
* i.e. In output of rank 3 case:
|
||||
* (i, j, k) => (i', j') where i' = i * step + k and j' = j
|
||||
* if dimension == 0
|
||||
* (i, j, k) => (i', j') where i' = i and j' = j * step + k
|
||||
* if dimension == 1
|
||||
*/
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
SmallVector<AffineExpr> outputExprs;
|
||||
for (int dim = 0; dim < selfRank; ++dim) {
|
||||
if (dim == dimension) {
|
||||
auto idxLast = getAffineDimExpr(selfRank, context);
|
||||
auto idxDimension = getAffineDimExpr(dimension, context);
|
||||
|
||||
AffineExpr dimIdx =
|
||||
idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
|
||||
outputExprs.push_back(dimIdx);
|
||||
} else {
|
||||
outputExprs.push_back(getAffineDimExpr(dim, context));
|
||||
}
|
||||
}
|
||||
|
||||
int64_t outputRank = selfRank + 1;
|
||||
auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
|
||||
auto outputAffineMap =
|
||||
AffineMap::getMultiDimIdentityMap(outputRank, context);
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
outputRank, utils::IteratorType::parallel);
|
||||
|
||||
Value result =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputTensor.getType(), self, outputTensor,
|
||||
ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes,
|
||||
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(nestedLoc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
|
||||
int64_t shapeDim, Value dimSize,
|
||||
int64_t size, int64_t step) const {
|
||||
/**
|
||||
* numBlocks = (shape[dimension] - size) // step + 1
|
||||
*/
|
||||
if (shapeDim == ShapedType::kDynamic) {
|
||||
Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
|
||||
loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
|
||||
Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
|
||||
loc, numBlocksSubOp,
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, step));
|
||||
Value numBlocks = rewriter.create<arith::AddIOp>(
|
||||
loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
|
||||
return OpFoldResult(numBlocks);
|
||||
}
|
||||
|
||||
int64_t staticNumBlocks = (shapeDim - size) / step + 1;
|
||||
return rewriter.getIndexAttr(staticNumBlocks); // Use static value
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
|
||||
public:
|
||||
|
@ -2679,7 +2840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
/*benefit=*/200);
|
||||
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
|
||||
/*benefit=*/100);
|
||||
|
||||
target.addIllegalOp<AtenUnfoldOp>();
|
||||
patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeDimOp>();
|
||||
|
|
|
@ -15588,6 +15588,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"size must be less than or equal to {}\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %5 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %15 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %12 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"";
|
||||
// clang-format on
|
||||
|
|
|
@ -278,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
|
|||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
|
||||
AtenPixelShuffleOp, AtenDiagonalOp>(op);
|
||||
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
|
||||
}
|
||||
|
||||
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||
|
|
|
@ -915,6 +915,11 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
"Unfold_Module_basic",
|
||||
"Unfold_Module_Rank_4",
|
||||
"Unfold_Module_Rank_Zero_basic",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
||||
|
@ -3158,6 +3163,10 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"UnfoldModule_basic",
|
||||
"Unfold_Module_Rank_4",
|
||||
"Unfold_Module_Rank_Zero_basic",
|
||||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
|
|
@ -5559,7 +5559,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
|
|||
return torch.qint8
|
||||
return torch.qint32
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
|
||||
Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
|
||||
Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
|
||||
Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
|
||||
Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
|
||||
Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
|
||||
])
|
||||
def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
|
||||
ndim = len(self)
|
||||
|
||||
# Rank zero tensor
|
||||
if ndim == 0:
|
||||
assert dimension == 0, f"dimension out of range of {ndim}"
|
||||
assert size <= 1, "size must be less than or equal to 1"
|
||||
return [size]
|
||||
|
||||
dim = dimension
|
||||
if dim < 0:
|
||||
dim += ndim
|
||||
|
||||
assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"
|
||||
|
||||
size_dim = self[dim]
|
||||
assert size <= size_dim, f"size must be less than or equal to {size_dim}"
|
||||
|
||||
num_blocks = (size_dim - size) // step + 1
|
||||
|
||||
out = upstream_shape_functions._copy(self)
|
||||
out[dim] = num_blocks
|
||||
out.append(size)
|
||||
return out
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
|
||||
)
|
||||
def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -992,6 +992,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
|
||||
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
|
||||
|
|
|
@ -1648,3 +1648,103 @@ class Rot90NegativeEvenRotationsModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
|
||||
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 5, 1, 7, 3))
|
||||
|
||||
|
||||
class Unfold_Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 2, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module())
|
||||
def Unfold_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
||||
|
||||
class Unfold_Module_Negative_Dim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4, 4, 4], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(-1, 2, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim())
|
||||
def Unfold_Module_Rank_4(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 4, 4))
|
||||
|
||||
|
||||
class Unfold_Module_Rank_Zero(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 1, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||
def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
|
||||
|
||||
class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(0, 0, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
|
||||
def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
|
||||
|
||||
class Unfold_Module_Dynamic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.unfold(1, 2, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
|
||||
def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 4, 4))
|
||||
|
|
Loading…
Reference in New Issue