mirror of https://github.com/llvm/torch-mlir

Added support for aten::norm.ScalarOpt_dim (#1774)

* Added support for aten::norm.ScalarOpt_dim
* Disable NormalizeModule_basic for linalg

parent a897c49803
commit c8b867b876
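For reference, the ATen overload this commit wires up has the signature (Tensor, Scalar?, int[], bool) -> (Tensor): an optional norm order p, a list of reduction dims, and a keepdim flag. A minimal sketch of the PyTorch-level call it models (the shapes here are illustrative, not taken from the patch):

import torch

x = torch.randn(3, 4)

# aten::norm.ScalarOpt_dim: (Tensor self, Scalar? p, int[] dim, bool keepdim).
# Leaving p as None falls back to the 2-norm over the listed dims.
row_norms = torch.norm(x, p=None, dim=[1], keepdim=True)
print(row_norms.shape)  # torch.Size([3, 1])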
@@ -4502,6 +4502,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
  }];
}

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalScalarType:$p,
    AnyTorchListOfTorchIntType:$dim,
    Torch_BoolType:$keepdim
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNormScalarOptDimOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenNormScalarOptDimOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}

def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -371,6 +371,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
    emit(
        "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
    )
    emit(
        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
    )
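The signature string registered above is what the ODS generator expands into the Torch_AtenNormScalarOptDimOp definition shown in the first hunk. A rough sketch of that correspondence (not the generator's actual code, just the naming convention visible in the two hunks):

def ods_names(signature: str):
    """Map an emit() signature to the generated op's class name, mnemonic, and arity."""
    name, rest = signature.split(" : ", 1)
    args, results = rest.split(" -> ")
    # "aten::norm.ScalarOpt_dim" -> "Torch_AtenNormScalarOptDimOp" / "aten.norm.ScalarOpt_dim"
    cls = "Torch_" + "".join(
        part[:1].upper() + part[1:]
        for piece in name.split("::")
        for part in piece.replace("_", ".").split(".")
    ) + "Op"
    mnemonic = name.replace("::", ".")
    num_operands = len(args.strip("()").split(", "))
    num_results = len(results.strip("()").split(", "))
    return cls, mnemonic, num_operands, num_results

sig = "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
print(ods_names(sig))
# ('Torch_AtenNormScalarOptDimOp', 'aten.norm.ScalarOpt_dim', 4, 1)

The recovered 4/1 arity is the same (4, 1) pair passed to parseDefaultTorchOp and printDefaultTorchOp in the generated definition above.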
@@ -8,6 +8,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "QuantizedMLP_basic",
    "NormalizeModule_basic",
}

def register_all_tests():
@@ -261,6 +261,25 @@ def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):

# ==============================================================================

class NormalizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 3], torch.float32, True),
    ])
    def forward(self, x):
        return torch.nn.functional.normalize(x)


@register_test_case(module_factory=lambda: NormalizeModule())
def NormalizeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3))


# ==============================================================================

class NativeLayerNormModule4D(torch.nn.Module):
    def __init__(self):
        super().__init__()
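The NormalizeModule test added above is what exercises the new op end to end: torch.nn.functional.normalize(x) divides by a p-norm taken over a dim list with keepdim=True, which is exactly the aten::norm.ScalarOpt_dim pattern (and the reason the same test is disabled in the XFAIL set above for now). A minimal sketch of that equivalence, assuming F.normalize's documented defaults (p=2, dim=1, eps=1e-12), which are not part of this patch:

import torch
import torch.nn.functional as F

x = torch.rand(3, 3)

# The reduction F.normalize performs internally: a 2-norm over dim 1 with keepdim,
# i.e. the (Tensor, Scalar?, int[], bool) shape of aten::norm.ScalarOpt_dim.
denom = torch.norm(x, p=2, dim=[1], keepdim=True).clamp_min(1e-12)

assert torch.allclose(F.normalize(x), x / denom)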