diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 0becb6686..99d00e287 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6203,6 +6203,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   }];
 }
 
+def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    Torch_BoolType:$use_input_stats,
+    Torch_FloatType:$momentum,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 1);
+    }
+    void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 50d4fae53..12b7ab559 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -392,6 +392,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                             binder.op, resultType, lhs, rhs);
         return success();
       });
+  patterns.onOp(
+      "InstanceNormalization", 6,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        float eps;
+
+        if (binder.tensorOperands(operands, 3) ||
+            binder.tensorResultType(resultType) || operands.size() != 3 ||
+            binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
+          return failure();
+        }
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value boolTrue =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value boolFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        auto epsValue = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(eps));
+
+        auto momentum = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+        rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
+            binder.op, resultType, /* input */ operands[0],
+            /* weight */ operands[1],
+            /* bias */ operands[2], /* running mean */ none,
+            /* running var */ none,
+            /* use input stats */ boolTrue, momentum, epsValue,
+            /* cudnn enabled */ boolFalse);
+        return success();
+      });
   patterns.onOp(
       "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 29c943042..39813da66 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8784,6 +8784,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
}\n" +" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9643,6 +9647,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index abd716c56..f9c1f63b5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3962,6 +3962,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenInstanceNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenInstanceNormOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.getInput().getType().cast(); + int64_t inputRank = inputTy.getSizes().size(); + auto reduceDimInts = + llvm::SmallVector({inputRank - 2, inputRank - 1}); + + SmallVector reducedShape(inputTy.getSizes()); + reducedShape[inputRank - 1] = 1; + reducedShape[inputRank - 2] = 1; + + Type dtype = inputTy.getOptionalDtype(); + Type reducedTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(reducedShape), dtype); + + auto sizeListType = ListType::get(IntType::get(context)); + SmallVector reduceDimVals; + reduceDimVals.reserve(reduceDimInts.size()); + std::transform(reduceDimInts.begin(), reduceDimInts.end(), + std::back_inserter(reduceDimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = + rewriter.create(loc, inputTy, 
+    Value inputMeanExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputMean, op.getInput());
+    Value inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, inputTy, op.getInput(), inputMeanExpanded, one);
+    // (x - mean(x))^2
+    Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputSubMean, inputSubMean);
+
+    Value variancesum = rewriter.create<AtenSumDimIntListOp>(
+        loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
+        /*dtype=*/none);
+
+    Value hw = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] *
+                                        inputTy.getSizes()[inputRank - 2]));
+    Value inputVar =
+        rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);
+
+    // rsqrt(var(x) + eps)
+    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
+        loc, reducedTy, inputVar, op.getEps(), one);
+    Value inputRsqrtVar =
+        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
+
+    // (x - mean(x)) * rsqrt(var(x) + eps)
+    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputRsqrtVar, op.getInput());
+    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
+    Value out = rewriter.create<TensorStaticInfoCastOp>(
+        loc, op.getResult().getType(), inputNormalized);
+
+    Value weight = op.getWeight();
+    auto weightTy = weight.getType().cast<BaseTensorType>();
+    dtype = weightTy.getOptionalDtype();
+
+    SmallVector<int64_t> weightShape(weightTy.getSizes());
+    SmallVector<int64_t> newWeightShape;
+    newWeightShape.push_back(1);
+    newWeightShape.append(weightShape);
+
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Type newWeightTy = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);
+
+    Value two = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(2));
+    newWeightShape.push_back(1);
+    newWeightTy = ValueTensorType::get(op.getContext(),
+                                       llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, two);
+
+    Value three = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(3));
+    newWeightShape.push_back(1);
+    newWeightTy = ValueTensorType::get(op.getContext(),
+                                       llvm::ArrayRef(newWeightShape), dtype);
+    weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, three);
+
+    Value weightExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
+
+    Value bias = op.getBias();
+    auto biasTy = bias.getType().cast<BaseTensorType>();
+    dtype = biasTy.getOptionalDtype();
+
+    SmallVector<int64_t> biasShape(biasTy.getSizes());
+    SmallVector<int64_t> newBiasShape;
+    newBiasShape.push_back(1);
+    newBiasShape.append(biasShape);
+
+    Type newBiasTy = ValueTensorType::get(op.getContext(),
+                                          llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);
+
+    newBiasShape.push_back(1);
+    newBiasTy = ValueTensorType::get(op.getContext(),
+                                     llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, two);
+
+    newBiasShape.push_back(1);
+    newBiasTy = ValueTensorType::get(op.getContext(),
+                                     llvm::ArrayRef(newBiasShape), dtype);
+    bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, three);
+
+    Value biasExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());
+
+    out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
+                                           weightExpanded);
+    out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
+                                           biasExpanded, one);
+
+    rewriter.replaceOp(op, out);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeLayerNormOp
     : public OpRewritePattern<AtenNativeLayerNormOp> {
@@ -6733,6 +6878,7 @@ public:
         DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 306b2446a..5d3488b11 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -409,6 +409,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   });
   target.addIllegalOp<AtenAddcmulOp>();
   target.addIllegalOp<AtenAddcdivOp>();
+  target.addIllegalOp<AtenInstanceNormOp>();
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
   target.addIllegalOp<AtenGroupNormOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a1cee9037..0cb088874 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -233,6 +233,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
 
     # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
+    "AtenInstanceNormModule_basic",
     "BatchNorm1DModule_basic",
     "BatchNorm1DWith2DInputModule_basic",
     "BatchNorm2DModule_basic",
@@ -898,6 +899,7 @@ TOSA_PASS_SET = {
     "AtenEyeModuleFalsePinMemory_basic",
     "AtenEyeModuleFloat2D_basic",
     "AtenRoundIntModule_basic",
+    "AtenInstanceNormModule_basic",
     "AtenToDeviceModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
@@ -1306,6 +1308,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
     "Conv2dNoPaddingModule_basic",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
+
+    "AtenInstanceNormModule_basic",
 }
 
 LTC_CRASHING_SET = {
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index c014808af..a856ac026 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1415,6 +1415,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona
 def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
     return upstream_shape_functions.unary(input), [N, group], [N, group]
 
+def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
@@ -2048,6 +2051,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
     assert not is_integer_dtype(input_dtype)
     return input_dtype, input_dtype, input_dtype
 
+# device is not supported hence unable to check the dtype function
+def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 65e9f44c1..1dc8585d7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -437,6 +437,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
+    )
     emit(
         "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
     )
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
index 3b17f516f..56821fb69 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -489,3 +489,21 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
 def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 3))
 
+class AtenInstanceNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 2, 1, 3], torch.float32, True),
+        ([2], torch.float32, True),
+        ([2], torch.float32, True)
+    ])
+    def forward(self, x, w, b):
+        return torch.ops.aten.instance_norm(x, w, b, None,
+                                            None, True, 0.0, 1e-05, False)
+
+@register_test_case(module_factory=lambda: AtenInstanceNormModule())
+def AtenInstanceNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index c5b28156a..8729e7f2d 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -603,6 +603,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
 
 // -----
 
+// CHECK-LABEL: func.func @test_instancenorm
+  func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+    // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32>
+    %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32>
+    return %0 : !torch.vtensor<[1,2,1,3],f32>
+  }
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
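
Note on the decomposition (not part of the patch): the ONNX importer pins running_mean/running_var to None, use_input_stats to true, and momentum to 0.0, so DecomposeAtenInstanceNormOp always normalizes with statistics computed from the input itself. Below is a minimal PyTorch sketch of that math, assuming a 4-D NCHW input with static shapes and an affine weight/bias, as in the AtenInstanceNormModule test above. The helper name reference_instance_norm is hypothetical and only illustrates the arithmetic (biased variance over H and W, then per-channel scale and shift):

    import torch

    def reference_instance_norm(x, w, b, eps=1e-05):
        # Reduce over the spatial dims (H, W), keeping dims so the statistics
        # broadcast back against the NCHW input, mirroring aten.mean.dim with
        # keepdim=true in the decomposition.
        mean = x.mean(dim=(2, 3), keepdim=True)
        # Biased variance, matching the decomposition's sum((x - mean)^2) / (H * W).
        var = ((x - mean) * (x - mean)).sum(dim=(2, 3), keepdim=True) \
              / (x.shape[2] * x.shape[3])
        normalized = (x - mean) * torch.rsqrt(var + eps)
        # weight and bias are per-channel [C]; view them as [1, C, 1, 1], as the
        # three unsqueeze ops in the decomposition do, before broadcasting.
        return normalized * w.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1)

    # Sanity check against aten.instance_norm with the same arguments the
    # e2e test passes (no running stats, use_input_stats=True).
    x, w, b = torch.rand(1, 2, 1, 3), torch.rand(2), torch.rand(2)
    expected = torch.ops.aten.instance_norm(x, w, b, None, None, True, 0.0, 1e-05, False)
    assert torch.allclose(reference_instance_norm(x, w, b), expected, atol=1e-6)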