Implement lowering of torch.aten.norm.Scalar (#2899)

Closes [nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
ptrifunovic98 2024-02-26 17:46:56 +01:00 committed by GitHub
parent 89e02c195b
commit c5a1da1910
8 changed files with 177 additions and 8 deletions


@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}
def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenNormScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}
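For reference, `aten::norm.Scalar` reduces its whole input tensor to a single p-norm value. A minimal eager-mode sketch of those semantics (the `reference_norm_scalar` helper is purely illustrative, not part of this change):

```python
import torch

def reference_norm_scalar(a: torch.Tensor, p: float) -> torch.Tensor:
    # p-norm over all elements: (sum_i |a_i|^p)^(1/p)
    return a.abs().pow(p).sum().pow(1.0 / p)

a = torch.rand(3, 4, 5)
# torch.ops.aten.norm(a, p) dispatches to the norm.Scalar overload defined above.
assert torch.allclose(torch.ops.aten.norm(a, 3.0), reference_norm_scalar(a, 3.0))
```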
def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth())));
}
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
if (isa<AtenAllDimOp>(op)) {
@@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (intType.isSigned())
return b.create<arith::MinSIOp>(loc, self, result);
}
} else if (isa<AtenNormScalarOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `p` is zero or one.
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
// TODO: Fix this part to support complex elements.
if (elem.getType().isa<mlir::ComplexType>()) {
op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
"is currently unimplemented");
return nullptr;
}
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
auto abs = b.create<math::AbsFOp>(loc, self);
AtenNormScalarOp::Adaptor adaptor(operands);
Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, p);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenLinalgVectorNormOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `ord` is zero or one.
@@ -433,7 +454,7 @@ private:
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
@@ -484,10 +505,12 @@ private:
return err ? Value{} : powOp;
}
FailureOr<Value> createSecondReductionForVectorNormOp(
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
template <typename TOp>
FailureOr<Value>
createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
Value ordOp, Value firstReduction,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
// Cast `ord` to float so that we can readily pass it to math.powf.
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
@@ -544,13 +567,15 @@ private:
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
// No checks are needed for the remaining reduction operations.
return success();
}
@@ -587,11 +612,22 @@ public:
return rewriter.notifyMatchFailure(
op, "failed to create linalg.generic operation for reduction");
// If this is an aten.norm.Scalar op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (isa<AtenNormScalarOp>(op)) {
AtenNormScalarOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
reduceOp = *secondReduceOp;
}
// If this is aten.linalg_vector_norm op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
@@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenMinOp>();
target.addIllegalOp<AtenAllDimOp>();
target.addIllegalOp<AtenNormScalarOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
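Taken together, the changes in this file lower `aten.norm.Scalar` to two `linalg.generic` ops: the first accumulates `sum(|x|^p)` with the payload added above (math.absf, math.powf, arith.addf over a zero-initialized accumulator), and the second, built by the now-templated `createSecondReductionForNormOp`, raises that sum to `1/p`. A rough Python model of the generated computation, with purely illustrative names:

```python
import numpy as np

def lowered_norm_scalar(x: np.ndarray, p: float) -> np.float32:
    # First linalg.generic: |elem|^p added into a zero-initialized scalar accumulator
    # (mirrors math.absf -> math.powf -> arith.addf in the payload above).
    acc = np.float32(0.0)
    for elem in x.ravel():
        acc += np.abs(elem) ** np.float32(p)
    # Second linalg.generic: raise the accumulated sum to 1/p (another math.powf).
    return acc ** np.float32(1.0 / p)

x = np.random.rand(3, 4).astype(np.float32)
assert np.allclose(lowered_norm_scalar(x, 3.0), np.sum(np.abs(x) ** 3.0) ** (1.0 / 3.0))
```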


@@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// AtenNormScalarOp
//===----------------------------------------------------------------------===//
LogicalResult AtenNormScalarOp::verify() {
// Verification of the input type for torch.aten.norm.Scalar.
// Per PyTorch docs, only float and complex types are valid for the norm
// operation.
auto inTensor = getSelf().getType().cast<BaseTensorType>();
// If no dtype is specified, it will default to a float one.
if (!inTensor.hasDtype()) {
return success();
}
auto inTensorDtype = inTensor.getDtype();
// Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
mlir::Float32Type, mlir::Float64Type>()) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< inTensorDtype;
}
return success();
}
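As a cross-check on this dtype rule, eager PyTorch also rejects integer inputs to this overload while accepting float inputs; a hedged sketch (the exact error message varies across PyTorch versions):

```python
import torch

try:
    torch.ops.aten.norm(torch.arange(6), 2.0)   # int64 input, expected to be rejected
except RuntimeError as e:
    print("rejected as expected:", e)

print(torch.ops.aten.norm(torch.rand(6), 2.0))  # float32 input is accepted
```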
//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//
LogicalResult AtenPermuteOp::verify() {
// Verification of the permute op for input & output dimensions with


@@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %1 = torch.derefine %none : !torch.none to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -12038,6 +12046,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
" %int8 = torch.constant.int 8\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int5 : !torch.int\n"
" } else {\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<Device>, %arg3: !torch.bool) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"

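The bare integer constants in the generated `__torch_mlir_dtype_fn.aten.norm.Scalar` function above are torch ScalarType codes; to the best of my knowledge 8 is complex32 (torch.chalf) and 5 is float16 (torch.half), so the branch encodes the complex32-to-float16 special case spelled out in the Python dtype function later in this commit:

```python
# Torch ScalarType codes referenced by the generated IR above (listed for
# readability; assumed from torch's dtype enumeration, not queried from an API):
TORCH_DTYPE_CODE = {
    5: "torch.float16 (torch.half)",
    8: "torch.complex32 (torch.chalf)",
}
# i.e. if the input dtype code is 8 (complex32), return 5 (float16);
# otherwise defer to the aten.std dtype function.
```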

@@ -1667,6 +1667,7 @@ ONNX_XFAIL_SET = {
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalFunctionalModule_basic",


@@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)
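Because `dim` is None and `keepdim` is False, `sum_mean_dim` collapses every dimension, so this shape function always yields the empty (rank-0) shape; a quick eager-mode check:

```python
import torch

# Full reduction: aten.norm.Scalar produces a 0-d tensor regardless of input rank.
assert torch.ops.aten.norm(torch.rand(3, 4, 5), 3.0).shape == torch.Size([])
```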
def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
@@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
return dtype
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
# The following check is added because aten〇std〡dtype
# does not handle the complex32-to-float conversion,
# so it is done manually here (torch.half == torch.float16).
# Should possibly be added to aten〇std〡dtype.
if self_dtype == torch.complex32:
return torch.half
return aten〇std〡dtype(self_rank_dtype)
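In effect: integer dtypes are rejected, complex32 maps to float16, and everything else follows `aten〇std〡dtype`, so float inputs keep their dtype and the remaining complex dtypes drop to their real counterparts. A couple of hedged eager-mode spot checks:

```python
import torch

print(torch.ops.aten.norm(torch.rand(4), 2.0).dtype)                         # torch.float32
print(torch.ops.aten.norm(torch.rand(4, dtype=torch.complex64), 2.0).dtype)  # expected: torch.float32
```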
@check_dtype_function([Invocation(0.0),
Invocation(0.0, dtype=torch.int32),
Invocation(0.0, dtype=torch.float16),


@@ -449,6 +449,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit(
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
)


@@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
# ==============================================================================
class NormScalarModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.p = 3.0
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.norm(a, self.p)
@register_test_case(module_factory=lambda: NormScalarModule())
def NormScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class NormScalarOptDimModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()