mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.norm.Scalar (#2899)
Closes [nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)

parent 89e02c195b
commit c5a1da1910
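For orientation (not part of the diff itself): aten::norm.Scalar computes the p-norm over every element of its input. A minimal eager-mode sketch of the semantics this patch lowers, using only stock PyTorch calls:

    import torch

    # ||x||_p = (sum_i |x_i|^p) ** (1/p), reduced over all elements.
    a = torch.rand(3, 4, 5)
    p = 3.0
    expected = (a.abs() ** p).sum() ** (1.0 / p)
    assert torch.allclose(torch.ops.aten.norm(a, p), expected)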
@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$p
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenNormScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
                                   elementType.getIntOrFloatBitWidth())));
   }
 
-  if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
+  if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
+      isa<AtenNormScalarOp>(op))
     return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
 
   if (isa<AtenAllDimOp>(op)) {
@@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
       if (intType.isSigned())
         return b.create<arith::MinSIOp>(loc, self, result);
     }
+  } else if (isa<AtenNormScalarOp>(op)) {
+    // This creates payload for only the first of the two linalg.generic ops.
+    // TODO: Short-circuit operations if `p` is zero or one.
+    Value elem = payloadArgs[0];
+    Value result = payloadArgs[1];
+
+    // TODO: Fix this part to support complex elements.
+    if (elem.getType().isa<mlir::ComplexType>()) {
+      op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
+                    "is currently unimplemented");
+      return nullptr;
+    }
+
+    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
+
+    auto abs = b.create<math::AbsFOp>(loc, self);
+    AtenNormScalarOp::Adaptor adaptor(operands);
+    Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
+    auto pow = b.create<math::PowFOp>(loc, abs, p);
+    return b.create<arith::AddFOp>(loc, pow, result);
   } else if (isa<AtenLinalgVectorNormOp>(op)) {
     // This creates payload for only the first of the two linalg.generic ops.
     // TODO: Short-circuit operations if `ord` is zero or one.
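As the comment notes, this payload only accumulates the running sum of |x|^p; the final ** (1/p) step is applied by a second linalg.generic wired up later in this patch. A rough Python sketch of what the first reduction alone produces (the helper name is illustrative, not from the patch):

    import torch

    def first_reduction(x: torch.Tensor, p: float) -> torch.Tensor:
        # Mirrors the payload: abs, powf, then an additive reduction.
        return (x.abs() ** p).sum()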
@@ -433,7 +454,7 @@ private:
                      ConversionPatternRewriter &rewriter) const {
     auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
 
-    if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
+    if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
       opInfo.tensorOperand = operands[0];
       auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
 
@@ -484,10 +505,12 @@ private:
     return err ? Value{} : powOp;
   }
 
-  FailureOr<Value> createSecondReductionForVectorNormOp(
-      Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
-      Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
-      ConversionPatternRewriter &rewriter) const {
+  template <typename TOp>
+  FailureOr<Value>
+  createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
+                                 Value ordOp, Value firstReduction,
+                                 const torch_to_linalg::ReductionOpInfo &opInfo,
+                                 ConversionPatternRewriter &rewriter) const {
     // Cast `ord` to float so that we can readily pass it math.powf.
     Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
 
@@ -544,13 +567,15 @@ private:
   LogicalResult
   validateReductionElementType(Operation *op, Type elemType,
                                ConversionPatternRewriter &rewriter) const {
-    if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
+    if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
+         isa<AtenNormScalarOp>(op)) &&
         !elemType.isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(
           op, "only float types are valid for vector norm ops");
     if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
         elemType.getIntOrFloatBitWidth() == 8)
       return rewriter.notifyMatchFailure(op, "uint8 is not supported");
 
     // No checks for all other reduction operations
     return success();
   }
@@ -587,11 +612,22 @@ public:
       return rewriter.notifyMatchFailure(
           op, "failed to create linalg.generic operation for reduction");
 
+    // If this is aten.norm.Scalar op, then we need to generate another
+    // linalg.generic op that references the first linalg.generic op.
+    if (isa<AtenNormScalarOp>(op)) {
+      AtenNormScalarOp::Adaptor adaptor(operands);
+      FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
+          loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
+      if (failed(secondReduceOp))
+        return secondReduceOp;
+      reduceOp = *secondReduceOp;
+    }
+
     // If this is aten.linalg_vector_norm op, then we need to generate another
     // linalg.generic op that references the first linalg.generic op.
     if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
       AtenLinalgVectorNormOp::Adaptor adaptor(operands);
-      FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
+      FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
           loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
       if (failed(secondReduceOp))
         return secondReduceOp;
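The second reduction produced by createSecondReductionForNormOp raises the accumulated sum to 1/p, completing the norm. A hedged Python sketch of the composed computation (again, the helper name is illustrative, not from the patch):

    import torch

    def norm_scalar_reference(x: torch.Tensor, p: float) -> torch.Tensor:
        partial = (x.abs() ** p).sum()  # first linalg.generic
        return partial ** (1.0 / p)     # second linalg.generic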
@@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp<AtenMaxOp>();
   target.addIllegalOp<AtenMinOp>();
   target.addIllegalOp<AtenAllDimOp>();
+  target.addIllegalOp<AtenNormScalarOp>();
   target.addIllegalOp<AtenLinalgVectorNormOp>();
   target.addIllegalOp<AtenFrobeniusNormDimOp>();
   patterns.add<ConvertReductionOp>(typeConverter, context);
@@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenNormScalarOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenNormScalarOp::verify() {
+
+  // Verification of input type for torch.aten.norm.Scalar.
+  // Per PyTorch docs, only float and complex types are valid for norm
+  // operation.
+
+  auto inTensor = getSelf().getType().cast<BaseTensorType>();
+
+  // If no dtype is specified, it will default to a float one.
+  if (!inTensor.hasDtype()) {
+    return success();
+  }
+
+  auto inTensorDtype = inTensor.getDtype();
+
+  // Check if dtype is one of those supported by norm operation.
+  // ComplexType will match any torch complex types, but each float must be
+  // checked individually.
+  if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
+                         mlir::Float32Type, mlir::Float64Type>()) {
+    return emitOpError(
+               "expected a float or complex type for input tensor, but got ")
+           << inTensorDtype;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenPermuteOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtenPermuteOp::verify() {
 
   // Verification of the permute op for input & output dimensions with
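The verifier only constrains the element type. A small Python analogue of the same check, assuming the usual torch dtypes (illustrative only, not part of the patch):

    import torch

    _VALID = {torch.float16, torch.float32, torch.float64,
              torch.complex32, torch.complex64, torch.complex128}

    def verify_norm_scalar_dtype(dtype: torch.dtype) -> None:
        # Mirrors AtenNormScalarOp::verify(): only float or complex inputs pass.
        if dtype not in _VALID:
            raise TypeError(
                f"expected a float or complex type for input tensor, but got {dtype}")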
@@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
+"    %1 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %int0 = torch.constant.int 0\n"
 "    %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -12038,6 +12046,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
+"    %true = torch.constant.bool true\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int8 = torch.constant.int 8\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
+"    %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int5 : !torch.int\n"
+"    } else {\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
+"    }\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<Device>, %arg3: !torch.bool) -> !torch.int {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %none = torch.constant.none\n"
@@ -1667,6 +1667,7 @@ ONNX_XFAIL_SET = {
     "NllLossModule_ignore_index_out_of_bounds_basic",
     "NllLossModule_mean_basic",
     "NllLossModule_sum_basic",
+    "NormScalarModule_basic",
     "NormScalarOptDimKeepDimModule_basic",
     "NormScalarOptDimModule_basic",
     "NormalFunctionalModule_basic",
@@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti
 def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
+def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, None, False, None)
+
 def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
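Because the shape function passes dim=None and keepdim=False to sum_mean_dim, the op reduces over every dimension and always yields a rank-0 result. A quick eager-mode illustration of that expectation:

    import torch

    assert torch.ops.aten.norm(torch.rand(3, 4, 5), 2.0).shape == torch.Size([])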
@@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
         return dtype
     return aten〇std〡dtype(self_rank_dtype)
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(
+        num_of_tensors=1,
+        error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
+def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_integer_dtype(self_dtype)
+    # The following check is added because aten〇std〡dtype
+    # does not handle complex32 transformation to float,
+    # so it is done manually (torch.half == torch.float16).
+    # Should possibly be added to aten〇std〡dtype.
+    if self_dtype == torch.complex32:
+        return torch.half
+    return aten〇std〡dtype(self_rank_dtype)
+
 @check_dtype_function([Invocation(0.0),
                        Invocation(0.0, dtype=torch.int32),
                        Invocation(0.0, dtype=torch.float16),
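The dtype rule defers to the aten.std rule except for complex32, so real inputs keep their dtype and complex inputs map to the matching real dtype. Two hedged eager-mode checks of that expectation (assuming standard PyTorch promotion behavior):

    import torch

    assert torch.ops.aten.norm(torch.rand(4, dtype=torch.float32), 2).dtype == torch.float32
    assert torch.ops.aten.norm(torch.rand(4, dtype=torch.complex64), 2).dtype == torch.float32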
@@ -449,6 +449,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
     emit(
         "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
     )
@@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class NormScalarModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.p = 3.0
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.norm(a, self.p)
+
+@register_test_case(module_factory=lambda: NormScalarModule())
+def NormScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
 class NormScalarOptDimModule(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
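The new e2e case exercises p = 3.0 on a random 3x4x5 input; conceptually the harness checks something like the following simplification (the real framework compares the compiled module against eager PyTorch):

    import torch

    m = NormScalarModule()
    x = torch.rand(3, 4, 5)
    assert torch.allclose(m.forward(x), (x.abs() ** 3.0).sum() ** (1.0 / 3.0))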