mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.renorm (#3388)
Closes [nod-ai/SHARK-Turbine/issues/689](https://github.com/nod-ai/SHARK-Turbine/issues/689) --------- Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>pull/3470/head
parent
59bade3376
commit
676fa8cc09
|
@ -6657,6 +6657,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$p,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchScalarType:$maxnorm
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenRenormOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -4655,6 +4655,80 @@ LogicalResult AtenNormScalarOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRenormOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtenRenormOp::verify() {
|
||||
|
||||
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||
|
||||
if (!selfType.hasDtype() || !selfType.hasSizes())
|
||||
return success();
|
||||
|
||||
auto inShape = selfType.getSizes();
|
||||
int64_t selfRank = inShape.size();
|
||||
auto selfDtype = selfType.getDtype();
|
||||
|
||||
if (!isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
|
||||
mlir::Float64Type, mlir::ComplexType>(selfDtype))
|
||||
return emitOpError(
|
||||
"expected a float or complex type for input tensor, but got ")
|
||||
<< selfDtype;
|
||||
|
||||
// According to the Pytoch documentation tensor need to be at least rank 2
|
||||
if (selfRank <= 1)
|
||||
return emitOpError("renorm: input needs at least 2 dimensions, got ")
|
||||
<< selfRank << " dimensions";
|
||||
|
||||
// Check if argument p is valid
|
||||
auto pType = getP().getType();
|
||||
|
||||
if (isa<mlir::ComplexType>(pType))
|
||||
return emitOpError("renorm: p must be real-valued");
|
||||
|
||||
// The argument 'p' can be either an integer or a floating-point number,
|
||||
// so we need to consider both options and check if 'p' is within the correct
|
||||
// range
|
||||
int64_t pInt = 1;
|
||||
double_t pDouble = 1;
|
||||
if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
|
||||
!matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
|
||||
return success();
|
||||
|
||||
if (pInt <= 0 || pDouble <= 0)
|
||||
return emitOpError("renorm: non-positive norm not supported");
|
||||
|
||||
// Check if argument maxnorm is valid
|
||||
auto maxnormType = getMaxnorm().getType();
|
||||
if (isa<mlir::ComplexType>(maxnormType))
|
||||
return emitOpError("renorm: maxnorm must be real-valued");
|
||||
|
||||
// The argument 'maxnorm' can be either an integer or a floating-point number,
|
||||
// so we need to consider both options and check if 'maxnorm' is within the
|
||||
// correct range
|
||||
int64_t maxnormInt = 0;
|
||||
double_t maxnormDouble = 0;
|
||||
if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
|
||||
!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
|
||||
return success();
|
||||
|
||||
if (maxnormInt < 0 || maxnormDouble < 0)
|
||||
return emitOpError("renorm: expected maxnorm to be >= 0");
|
||||
|
||||
// Get the dimension
|
||||
int64_t dim;
|
||||
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
|
||||
return success();
|
||||
|
||||
// check if is dim is in the correct range
|
||||
if (dim >= selfRank || dim < -selfRank)
|
||||
return emitOpError("Dimension out of range (expected to be in range of [")
|
||||
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenPermuteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -10119,6 +10119,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
@ -13162,6 +13165,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
|
|
|
@ -2069,6 +2069,145 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229
|
||||
// Decompose aten.renorm into: linalg_vector_norm
|
||||
namespace {
|
||||
class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRenormOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
Value dim = op.getDim();
|
||||
Value p = op.getP();
|
||||
Value maxnorm = op.getMaxnorm();
|
||||
|
||||
// Prepare all necessary variables
|
||||
auto ndim = getTensorRank(self);
|
||||
auto resType = cast<BaseTensorType>(self.getType());
|
||||
|
||||
if (!resType.hasDtype() || !resType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"result should have dtype and sizes");
|
||||
}
|
||||
|
||||
Type dtype = resType.getDtype();
|
||||
if (isa<mlir::ComplexType>(dtype)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "lowering of aten.renorm for complex inputs dtype is "
|
||||
"currently unimplemented");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> inputSize(resType.getSizes());
|
||||
|
||||
// Convert dim from Value to int
|
||||
int64_t dimInt;
|
||||
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: dim not constant int");
|
||||
|
||||
// Define all constants
|
||||
Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
|
||||
Value cstNone = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
// Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... ,
|
||||
// ndim-1]
|
||||
llvm::SmallVector<Value> reduceDimsVector;
|
||||
for (u_int64_t i = 0; i < ndim; i++) {
|
||||
if (i == (u_int64_t)dimInt)
|
||||
continue;
|
||||
|
||||
Value constI = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
|
||||
reduceDimsVector.push_back(constI);
|
||||
}
|
||||
|
||||
Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc,
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||
reduceDimsVector);
|
||||
|
||||
// Make output shape for linalg.vector_norm operation
|
||||
SmallVector<Value> inputSizeValue;
|
||||
for (u_int64_t i = 0; i < inputSize.size(); i++) {
|
||||
if (i != (u_int64_t)dimInt)
|
||||
inputSize[i] = 1;
|
||||
|
||||
inputSizeValue.push_back(
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
|
||||
}
|
||||
|
||||
// Prepare arguments for linalg.vector_norm
|
||||
Value dtypeValue;
|
||||
Type vectorNormOutType;
|
||||
|
||||
if (isa<mlir::Float16Type, mlir::BFloat16Type>(dtype)) {
|
||||
dtype = cast<Type>(rewriter.getF32Type());
|
||||
dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype);
|
||||
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
|
||||
} else {
|
||||
dtypeValue = cstNone;
|
||||
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
|
||||
}
|
||||
|
||||
auto norm = rewriter.create<AtenLinalgVectorNormOp>(
|
||||
loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue);
|
||||
|
||||
// Define epsiolon constant 10^-7
|
||||
mlir::FloatType f64Type = rewriter.getF64Type();
|
||||
Value epsValue = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getFloatAttr(f64Type, 1e-7));
|
||||
|
||||
Value normPlusEps = rewriter.create<AtenAddScalarOp>(
|
||||
loc, vectorNormOutType, norm, epsValue, cstOne);
|
||||
|
||||
Value maxnormTensorValue = rewriter.create<AtenFullLikeOp>(
|
||||
loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone,
|
||||
cstNone, cstNone, cstNone);
|
||||
|
||||
// Divide maxnorm and normPlusEps
|
||||
auto divideMaxnormAndNorm = rewriter.create<AtenDivTensorOp>(
|
||||
loc, vectorNormOutType, maxnormTensorValue, normPlusEps);
|
||||
|
||||
// Next few lines corespond to this pythorch code: norm_factor =
|
||||
// torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
|
||||
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
||||
cast<BaseTensorType>(vectorNormOutType).getOptionalSizes(),
|
||||
rewriter.getI1Type());
|
||||
|
||||
Value greaterThanMaxnorm =
|
||||
rewriter.create<AtenGtScalarOp>(loc, boolTensorType, norm, maxnorm);
|
||||
|
||||
Value cstOnetensor = rewriter.create<AtenFullLikeOp>(
|
||||
loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone,
|
||||
cstNone, cstNone, cstNone);
|
||||
|
||||
auto normFactor = rewriter.create<AtenWhereSelfOp>(
|
||||
loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm,
|
||||
cstOnetensor);
|
||||
|
||||
// Converte norm_factor to input dtype
|
||||
Value normFactorFinal = rewriter.create<PrimsConvertElementTypeOp>(
|
||||
loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()),
|
||||
normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype()));
|
||||
|
||||
// Multiply input tensor with norm factor
|
||||
auto output = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self,
|
||||
normFactorFinal);
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenContiguousOp>(op, self.getType(), output,
|
||||
/*memory_format*/ cstZero);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
|
||||
// aten.add.Tensor and aten.mull.Tensor. See
|
||||
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
|
||||
|
@ -8081,6 +8220,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
||||
|
|
|
@ -402,6 +402,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenNormScalarOptDimOp>();
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
target.addIllegalOp<AtenMvOp>();
|
||||
target.addIllegalOp<AtenRenormOp>();
|
||||
target.addIllegalOp<AtenLinalgCrossOp>();
|
||||
target.addIllegalOp<AtenPixelShuffleOp>();
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
|
|
|
@ -1473,6 +1473,9 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseLogSigmoidModule_basic",
|
||||
"ElementwiseHardshrinkStaticModule_basic",
|
||||
"ElementwiseSoftshrinkStaticModule_basic",
|
||||
"RenormModuleFloat16_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = set()
|
||||
|
@ -1949,6 +1952,8 @@ TOSA_PASS_SET = {
|
|||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (
|
||||
|
@ -1982,6 +1987,8 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
@ -2695,6 +2702,11 @@ ONNX_XFAIL_SET = {
|
|||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||
# RuntimeError: unsupported input type: Device
|
||||
"PrimsIotaModule_basic",
|
||||
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.
|
||||
"RenormModuleFloat16_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"RenormModuleFloat32DynamicDims_basic",
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
|
|
|
@ -1998,6 +1998,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim
|
|||
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
|
||||
|
||||
def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, None, False, None)
|
||||
|
||||
|
@ -4416,6 +4419,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
|
|||
return dtype
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
tensor_shapes=[(3,3)],
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64},
|
||||
p=1,
|
||||
dim=0,
|
||||
maxnorm=5)
|
||||
)
|
||||
def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(
|
||||
num_of_tensors=1,
|
||||
|
|
|
@ -587,6 +587,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
|
|
|
@ -633,3 +633,96 @@ class AtenInstanceNormModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: AtenInstanceNormModule())
|
||||
def AtenInstanceNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class RenormModuleFloat32(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = 2
|
||||
self.dim = 1
|
||||
self.maxnorm = 10
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RenormModuleFloat32())
|
||||
def RenormModuleFloat32_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
|
||||
class RenormModuleFloat16(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = 2.1
|
||||
self.dim = 1
|
||||
self.maxnorm = 10
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 4, 5], torch.float16, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RenormModuleFloat16())
|
||||
def RenormModuleFloat16_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float16))
|
||||
|
||||
|
||||
class RenormModuleFloat32NegativeDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = 2.3
|
||||
self.dim = -1
|
||||
self.maxnorm = 5.2
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 4, 5, 2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RenormModuleFloat32NegativeDim())
|
||||
def RenormModuleFloat32NegativeDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 4, 5, 2).to(torch.float32))
|
||||
|
||||
|
||||
class RenormModuleFloat32DynamicDims(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = 2
|
||||
self.dim = 1
|
||||
self.maxnorm = 10
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims())
|
||||
def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 3))
|
||||
|
|
Loading…
Reference in New Issue