mirror of https://github.com/llvm/torch-mlir
[LINALG] Implement lowering of torch.aten.rot90 (#3551)
parent
d4b5e05ac1
commit
70d5730c87
|
@ -9322,6 +9322,32 @@ def Torch_Aten_WeightNormInterfaceOp : Torch_Op<"aten._weight_norm_interface", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRot90Op : Torch_Op<"aten.rot90", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::rot90 : (Tensor, int, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$k,
|
||||
AnyTorchListOfTorchIntType:$dims
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRot90Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenRot90Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -5471,3 +5471,36 @@ LogicalResult AtenTrilIndicesOp::verify() {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenRot90Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtenRot90Op::verify() {
|
||||
// Check rotation dimensions.
|
||||
SmallVector<Value> dims;
|
||||
if (!getListConstructElements(getDims(), dims))
|
||||
return success();
|
||||
|
||||
if (dims.size() != 2)
|
||||
return emitOpError("expected total rotation dims == 2, but got dims = ")
|
||||
<< dims.size();
|
||||
|
||||
// Check a rank of the input tensor.
|
||||
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||
if (!selfType.hasSizes())
|
||||
return success();
|
||||
|
||||
auto selfShape = selfType.getSizes();
|
||||
int64_t selfRank = selfShape.size();
|
||||
|
||||
if (selfRank < 2)
|
||||
return emitOpError("expected total dims >= 2, but got total dims = ")
|
||||
<< selfRank;
|
||||
|
||||
if (dims[0] == dims[1])
|
||||
return emitOpError(
|
||||
"expected rotation dims to be different, but got dim0 = ")
|
||||
<< dims[0] << " and dim1 = " << dims[1];
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -8985,6 +8985,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
|
||||
" return %13 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.rot90\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %str = torch.constant.str \"expected total rotation dims == 2, but got dims = {}\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_1 = torch.constant.str \"expected total dims >= 2 but got {}\"\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.format(%str_1, %9) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %11 = torch.aten.add.str %str_0, %10 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %11, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.format(%str, %9) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %11 = torch.aten.add.str %str_0, %10 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %11, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.remainder.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.add.int %4, %int4 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.aten.remainder.int %5, %int4 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %9 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" %9 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.__getitem__.t %arg0, %11 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten._set_item.t %arg0, %13, %10 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten._set_item.t %arg0, %15, %12 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -14795,6 +14853,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rot90\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -5500,6 +5500,72 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.rot90
|
||||
// github:
|
||||
// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3830
|
||||
namespace {
|
||||
class DecomposeAtenRot90Op : public OpRewritePattern<AtenRot90Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRot90Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
|
||||
// Convert dims from Value to SmallVector.
|
||||
SmallVector<Value> dims;
|
||||
if (!getListConstructElements(op.getDims(), dims))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dims not list of Scalar");
|
||||
|
||||
// Convert k from Value to int
|
||||
int64_t k;
|
||||
if (!matchPattern(op.getK(), m_TorchConstantInt(&k)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: k not constant int");
|
||||
|
||||
k = (k % 4 + 4) %
|
||||
4; // This is equal to python code k = k % 4, because python and c++
|
||||
// have different implementation for operand %.
|
||||
|
||||
if (k == 1) {
|
||||
Value flipDimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc,
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||
ArrayRef{dims[1]});
|
||||
|
||||
Value flip =
|
||||
rewriter.create<AtenFlipOp>(loc, self.getType(), self, flipDimList);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
|
||||
op, op.getType(), flip, dims[0], dims[1]);
|
||||
} else if (k == 2) {
|
||||
rewriter.replaceOpWithNewOp<AtenFlipOp>(op, op.getType(), self,
|
||||
op.getDims());
|
||||
} else if (k == 3) {
|
||||
Value flipDimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc,
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||
ArrayRef{dims[0]});
|
||||
|
||||
Value flip =
|
||||
rewriter.create<AtenFlipOp>(loc, self.getType(), self, flipDimList);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
|
||||
op, op.getType(), flip, dims[0], dims[1]);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<AtenCloneOp>(
|
||||
op, op.getType(), self,
|
||||
/*memory_format=*/
|
||||
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(0)));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.std.correction to sqrt(var.correction(x))
|
||||
namespace {
|
||||
class DecomposeAtenStdCorrectionOp
|
||||
|
@ -9603,6 +9669,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRot90Op>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
||||
|
|
|
@ -406,6 +406,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
target.addIllegalOp<AtenMvOp>();
|
||||
target.addIllegalOp<AtenRenormOp>();
|
||||
target.addIllegalOp<AtenRot90Op>();
|
||||
target.addIllegalOp<AtenLinalgCrossOp>();
|
||||
target.addIllegalOp<Aten_LinalgDetOp>();
|
||||
target.addIllegalOp<AtenLinalgSlogdetOp>();
|
||||
|
|
|
@ -1326,6 +1326,10 @@ STABLEHLO_PASS_SET = {
|
|||
"ReturnThreeTensorFloat32_basic",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"RollModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
|
@ -2693,6 +2697,7 @@ ONNX_XFAIL_SET = {
|
|||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
|
@ -2865,6 +2870,11 @@ ONNX_XFAIL_SET = {
|
|||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"RenormModuleFloat32DynamicDims_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDymsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
|
|
|
@ -1352,6 +1352,25 @@ def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: L
|
|||
def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]:
|
||||
return _diag_embed_shape_helper(self, offset, dim1, dim2)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||
Invocation(TensorOfShape(5, 3, 4), k = 5, dims=(1, 2,)), # multiple times rotation
|
||||
Invocation(TensorOfShape(3, 5, 2), k = -2), # neagtive direction, remainder=2
|
||||
Invocation(TensorOfShape(7, 2, 6, 3), k = -5), # neagtive direction, remainder=3
|
||||
ErrorInvocation(TensorOfShape(2, 3, 4), dims=(0,)), # total lenght of the dims is < 2
|
||||
ErrorInvocation(TensorOfShape(2)), # the input is one-dimensional
|
||||
])
|
||||
def aten〇rot90〡shape(self: List[int], k: int = 1, dims: List[int] = (0, 1,)) -> List[int]:
|
||||
assert len(self) >= 2, "expected total dims >= 2 but got {}".format(len(self))
|
||||
assert len(dims) == 2, "expected total rotation dims == 2, but got dims = {}".format(len(dims))
|
||||
|
||||
k = (k % 4 + 4) % 4 # equal to k % 4, but 'k % 4' cannot handle negative values for k.
|
||||
|
||||
if k == 1 or k == 3:
|
||||
self[dims[0]], self[dims[1]] = self[dims[1]], self[dims[0]]
|
||||
|
||||
return self
|
||||
|
||||
def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -5095,6 +5114,10 @@ def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0,
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
def aten〇rot90〡dtype(self_rank_dtype: Tuple[int, int], k: int = 1, dims: List[int] = (0, 1,)) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
|
||||
|
|
|
@ -747,6 +747,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
)
|
||||
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
|
||||
emit("aten::rot90 : (Tensor, int, int[]) -> (Tensor)", has_verifier=True)
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
|
|
|
@ -1530,61 +1530,121 @@ def Atleast1dModule1dInput_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class Atleast2dModule0dInput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
class Rot90BasicModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([], torch.float32, True),
|
||||
([4, 5], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.atleast_2d(x)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.rot90(
|
||||
a,
|
||||
k=1,
|
||||
dims=(
|
||||
0,
|
||||
1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Atleast2dModule0dInput())
|
||||
def Atleast2dModule0dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
@register_test_case(module_factory=lambda: Rot90BasicModule())
|
||||
def Rot90BasicModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5))
|
||||
|
||||
|
||||
class Atleast2dModule1dInput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
class Rot90DynamicDimsModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.atleast_2d(x)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.rot90(
|
||||
a,
|
||||
k=1,
|
||||
dims=(
|
||||
0,
|
||||
1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Atleast2dModule1dInput())
|
||||
def Atleast2dModule1dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
@register_test_case(module_factory=lambda: Rot90DynamicDimsModule())
|
||||
def Rot90DynamicDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 2, 4))
|
||||
|
||||
|
||||
class Atleast2dModule2dInput(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
class Rot90MultipleRotationsModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([4, 4], torch.float32, True),
|
||||
([7, 4, 6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.atleast_2d(x)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.rot90(
|
||||
a,
|
||||
k=6,
|
||||
dims=(
|
||||
1,
|
||||
2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Atleast2dModule2dInput())
|
||||
def Atleast2dModule2dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4))
|
||||
@register_test_case(module_factory=lambda: Rot90MultipleRotationsModule())
|
||||
def Rot90MultipleRotationsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(7, 4, 6))
|
||||
|
||||
|
||||
class Rot90NegativeOddRotationsModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([7, 4, 6, 5, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.rot90(
|
||||
a,
|
||||
k=-5,
|
||||
dims=(
|
||||
1,
|
||||
2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Rot90NegativeOddRotationsModule())
|
||||
def Rot90NegativeOddRotationsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(7, 4, 6, 5, 3))
|
||||
|
||||
|
||||
class Rot90NegativeEvenRotationsModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 5, 1, 7, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.rot90(
|
||||
a,
|
||||
k=-6,
|
||||
dims=(
|
||||
1,
|
||||
-2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule())
|
||||
def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 5, 1, 7, 3))
|
||||
|
|
Loading…
Reference in New Issue