Implement lowering of torch.aten.renorm (#3388)

Closes
[nod-ai/SHARK-Turbine/issues/689](https://github.com/nod-ai/SHARK-Turbine/issues/689)
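
For context, `aten.renorm(self, p, dim, maxnorm)` rescales every sub-tensor of `self` taken along `dim` whose `p`-norm exceeds `maxnorm` so that its norm drops to `maxnorm`, and leaves the remaining sub-tensors unchanged. A minimal PyTorch sketch of the behavior being lowered (shapes and values are illustrative only):

```python
import torch

x = torch.randn(3, 4)

# Renormalize the sub-tensors along dim=0 (here: the rows) whose 2-norm exceeds 1.0.
y = torch.renorm(x, p=2, dim=0, maxnorm=1.0)

# Every row norm of the result is bounded by maxnorm
# (renorm scales oversized rows by maxnorm / (norm + 1e-7)).
assert torch.all(torch.linalg.vector_norm(y, ord=2, dim=1) <= 1.0 + 1e-5)
```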

---------

Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>
Branko Trifkovic 2024-06-17 19:40:57 +02:00 committed by GitHub
parent 59bade3376
commit 676fa8cc09
9 changed files with 382 additions and 0 deletions


@@ -6657,6 +6657,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}
def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p,
Torch_IntType:$dim,
AnyTorchScalarType:$maxnorm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRenormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasVerifier = 1;
}
def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -4655,6 +4655,80 @@ LogicalResult AtenNormScalarOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// AtenRenormOp
//===----------------------------------------------------------------------===//
LogicalResult AtenRenormOp::verify() {
auto selfType = cast<BaseTensorType>(getSelf().getType());
if (!selfType.hasDtype() || !selfType.hasSizes())
return success();
auto inShape = selfType.getSizes();
int64_t selfRank = inShape.size();
auto selfDtype = selfType.getDtype();
if (!isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
mlir::Float64Type, mlir::ComplexType>(selfDtype))
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< selfDtype;
// According to the PyTorch documentation, the input tensor needs to be at least rank 2
if (selfRank <= 1)
return emitOpError("renorm: input needs at least 2 dimensions, got ")
<< selfRank << " dimensions";
// Check if argument p is valid
auto pType = getP().getType();
if (isa<mlir::ComplexType>(pType))
return emitOpError("renorm: p must be real-valued");
// The argument 'p' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'p' is within the correct
// range
int64_t pInt = 1;
double_t pDouble = 1;
if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
!matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
return success();
if (pInt <= 0 || pDouble <= 0)
return emitOpError("renorm: non-positive norm not supported");
// Check if argument maxnorm is valid
auto maxnormType = getMaxnorm().getType();
if (isa<mlir::ComplexType>(maxnormType))
return emitOpError("renorm: maxnorm must be real-valued");
// The argument 'maxnorm' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'maxnorm' is within the
// correct range
int64_t maxnormInt = 0;
double_t maxnormDouble = 0;
if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
return success();
if (maxnormInt < 0 || maxnormDouble < 0)
return emitOpError("renorm: expected maxnorm to be >= 0");
// Get the dimension
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
return success();
// Check whether dim is within the valid range
if (dim >= selfRank || dim < -selfRank)
return emitOpError("Dimension out of range (expected to be in range of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
return success();
}
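
These checks mirror the argument validation that eager PyTorch performs for `torch.renorm`. A small illustrative sketch of the corresponding eager-mode failures (the exact error messages come from PyTorch and are only paraphrased in the comments):

```python
import torch

x = torch.randn(3, 3)

# Rank-1 inputs are rejected: renorm needs at least 2 dimensions.
try:
    torch.renorm(torch.randn(5), p=2, dim=0, maxnorm=1.0)
except RuntimeError as e:
    print(e)

# A non-positive norm order p is rejected.
try:
    torch.renorm(x, p=-1, dim=0, maxnorm=1.0)
except RuntimeError as e:
    print(e)

# A negative maxnorm is rejected.
try:
    torch.renorm(x, p=2, dim=0, maxnorm=-1.0)
except RuntimeError as e:
    print(e)
```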
//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//


@@ -10119,6 +10119,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
@@ -13162,6 +13165,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"


@@ -2069,6 +2069,145 @@ public:
};
} // namespace
// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229
// Decompose aten.renorm into aten.linalg_vector_norm and elementwise ops
namespace {
class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRenormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value dim = op.getDim();
Value p = op.getP();
Value maxnorm = op.getMaxnorm();
// Prepare all necessary variables
auto ndim = getTensorRank(self);
auto resType = cast<BaseTensorType>(self.getType());
if (!resType.hasDtype() || !resType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"result should have dtype and sizes");
}
Type dtype = resType.getDtype();
if (isa<mlir::ComplexType>(dtype)) {
return rewriter.notifyMatchFailure(
op, "lowering of aten.renorm for complex inputs dtype is "
"currently unimplemented");
}
SmallVector<int64_t> inputSize(resType.getSizes());
// Convert dim from Value to int
int64_t dimInt;
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op,
"Unimplemented: dim not constant int");
// Define all constants
Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cstNone = rewriter.create<ConstantNoneOp>(loc);
// Arrange the reduce_dims vector: [0, 1, ..., dim-1, dim+1, ..., ndim-1]
llvm::SmallVector<Value> reduceDimsVector;
for (uint64_t i = 0; i < ndim; i++) {
if (i == (uint64_t)dimInt)
continue;
Value constI = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
reduceDimsVector.push_back(constI);
}
Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
reduceDimsVector);
// Make output shape for linalg.vector_norm operation
SmallVector<Value> inputSizeValue;
for (uint64_t i = 0; i < inputSize.size(); i++) {
if (i != (uint64_t)dimInt)
inputSize[i] = 1;
inputSizeValue.push_back(
rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
}
// Prepare arguments for linalg.vector_norm
Value dtypeValue;
Type vectorNormOutType;
if (isa<mlir::Float16Type, mlir::BFloat16Type>(dtype)) {
dtype = cast<Type>(rewriter.getF32Type());
dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype);
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
} else {
dtypeValue = cstNone;
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
}
auto norm = rewriter.create<AtenLinalgVectorNormOp>(
loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue);
// Define the epsilon constant 1e-7
mlir::FloatType f64Type = rewriter.getF64Type();
Value epsValue = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(f64Type, 1e-7));
Value normPlusEps = rewriter.create<AtenAddScalarOp>(
loc, vectorNormOutType, norm, epsValue, cstOne);
Value maxnormTensorValue = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone,
cstNone, cstNone, cstNone);
// Divide maxnorm by normPlusEps
auto divideMaxnormAndNorm = rewriter.create<AtenDivTensorOp>(
loc, vectorNormOutType, maxnormTensorValue, normPlusEps);
// The next few lines correspond to this PyTorch code:
// norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
auto boolTensorType = rewriter.getType<ValueTensorType>(
cast<BaseTensorType>(vectorNormOutType).getOptionalSizes(),
rewriter.getI1Type());
Value greaterThanMaxnorm =
rewriter.create<AtenGtScalarOp>(loc, boolTensorType, norm, maxnorm);
Value cstOnetensor = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone,
cstNone, cstNone, cstNone);
auto normFactor = rewriter.create<AtenWhereSelfOp>(
loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm,
cstOnetensor);
// Convert norm_factor to the input dtype
Value normFactorFinal = rewriter.create<PrimsConvertElementTypeOp>(
loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()),
normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype()));
// Multiply the input tensor by the norm factor
auto output = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self,
normFactorFinal);
rewriter.replaceOpWithNewOp<AtenContiguousOp>(op, self.getType(), output,
/*memory_format*/ cstZero);
return success();
}
};
} // namespace
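
In eager-mode terms, the pattern above amounts to the reference computation below (a sketch mirroring the torch._refs decomposition linked above; the helper name and the negative-dim handling via modulo are illustrative, and the float16/bfloat16 promotion to f32 that the pattern performs is omitted):

```python
import torch

def renorm_reference(x: torch.Tensor, p: float, dim: int, maxnorm: float) -> torch.Tensor:
    """Sketch of what DecomposeAtenRenormOp emits, expressed in PyTorch."""
    dim = dim % x.dim()
    # aten.linalg_vector_norm over every dimension except `dim`, keeping the
    # reduced dims so the result broadcasts against the input.
    reduce_dims = [d for d in range(x.dim()) if d != dim]
    norm = torch.linalg.vector_norm(x, ord=p, dim=reduce_dims, keepdim=True)

    # eps keeps the division well-defined when a sub-tensor has zero norm.
    eps = 1e-7
    norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), torch.ones_like(norm))

    # Scale in the input dtype and return a contiguous result, as the pattern does.
    return (x * norm_factor.to(x.dtype)).contiguous()

# The result should match eager torch.renorm up to floating-point rounding.
x = torch.randn(3, 4)
torch.testing.assert_close(renorm_reference(x, 2.0, 1, 5.0), torch.renorm(x, 2, 1, 5.0))
```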
// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
// aten.add.Tensor and aten.mull.Tensor. See
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
@@ -8081,6 +8220,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);


@@ -402,6 +402,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenRenormOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();


@@ -1473,6 +1473,9 @@ STABLEHLO_PASS_SET = {
"ElementwiseLogSigmoidModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseSoftshrinkStaticModule_basic",
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
STABLEHLO_CRASHING_SET = set()
@@ -1949,6 +1952,8 @@ TOSA_PASS_SET = {
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
MAKE_FX_TOSA_PASS_SET = (
@@ -1982,6 +1987,8 @@ MAKE_FX_TOSA_PASS_SET = (
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
@@ -2695,6 +2702,11 @@ ONNX_XFAIL_SET = {
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"RenormModuleFloat32DynamicDims_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv_Transpose1dModule_basic",


@@ -1998,6 +1998,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]:
return self
def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)
@@ -4416,6 +4419,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
return dtype
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3,3)],
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64},
p=1,
dim=0,
maxnorm=5)
)
def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
return self_dtype
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,


@@ -587,6 +587,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
emit(


@@ -633,3 +633,96 @@ class AtenInstanceNormModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AtenInstanceNormModule())
def AtenInstanceNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
# ==============================================================================
class RenormModuleFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = 2
self.dim = 1
self.maxnorm = 10
@export
@annotate_args(
[
None,
([3, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
@register_test_case(module_factory=lambda: RenormModuleFloat32())
def RenormModuleFloat32_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3))
class RenormModuleFloat16(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = 2.1
self.dim = 1
self.maxnorm = 10
@export
@annotate_args(
[
None,
([3, 4, 5], torch.float16, True),
]
)
def forward(self, x):
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
@register_test_case(module_factory=lambda: RenormModuleFloat16())
def RenormModuleFloat16_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float16))
class RenormModuleFloat32NegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = 2.3
self.dim = -1
self.maxnorm = 5.2
@export
@annotate_args(
[
None,
([1, 4, 5, 2], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
@register_test_case(module_factory=lambda: RenormModuleFloat32NegativeDim())
def RenormModuleFloat32NegativeDim_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 4, 5, 2).to(torch.float32))
class RenormModuleFloat32DynamicDims(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = 2
self.dim = 1
self.maxnorm = 10
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)
@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims())
def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 3))