diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 1804cdb85..4556cd87a 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -264,6 +264,7 @@ STABLEHLO_PASS_SET = {
     "Mv_basic",
     "NativeLayerNormModule4D_basic",
     "NativeLayerNormModule_basic",
+    "OneHotModule_basic",
     "PrimsConvertElementTypeModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
     "ReduceSumDimIntListElementTypeBoolModule_basic",
@@ -935,4 +936,5 @@ LTC_XFAIL_SET = {
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "OneHotModule_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ad0e7dbe9..3b1f1cfe1 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6225,6 +6225,30 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
   }];
 }
 
+def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::one_hot : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$num_classes
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenOneHotOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenOneHotOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index d46c61c61..ae109c283 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6435,6 +6435,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %0 = torch.aten.ne.int %arg1, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %0 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
+"    %2 = torch.aten.add.t %arg0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 213e92f5a..e2acf5caa 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4165,6 +4165,65 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOneHotOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+
+    Value input = op.getSelf();
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    int64_t inputRank = inputType.getSizes().size();
+    int64_t numClasses;
+    if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: num_classes must be constant");
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
+
+    // arange tensor
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    auto arangeType =
+        ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type);
+    Value arangeTensor = rewriter.create<AtenArangeOp>(
+        loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none);
+
+    // unsqueeze input
+    llvm::SmallVector<int64_t> unsqueezeShape(inputType.getSizes());
+    unsqueezeShape.push_back(1);
+    auto unsqueezeType =
+        ValueTensorType::get(context, unsqueezeShape, si64Type);
+    Value unsqueezeTensor = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezeType, input,
+        rewriter.create<ConstantIntOp>(loc,
+                                       rewriter.getI64IntegerAttr(inputRank)));
+
+    // compare
+    auto eqType = ValueTensorType::get(
+        context, op.getType().cast<BaseTensorType>().getSizes(),
+        IntegerType::get(context, 1));
+    Value eqTensor = rewriter.create<AtenEqTensorOp>(
+        loc, eqType, unsqueezeTensor, arangeTensor);
+
+    // convert to si64
+    Value si64TypeValue =
+        Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
+    Value result = rewriter.create<AtenToDtypeOp>(
+        loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
+        /*copy=*/falseValue, /*memory_format=*/none);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4325,6 +4384,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6b6a84de8..a17c1577c 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenOneHotOp>();
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
         OperationName(kTorchOpPrefix + opName.first().str(), context));
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b1b7c0ded..70df807cb 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -659,7 +659,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
           AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
           AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp,
-          PrimsSqueezeOp>(op)) {
+          PrimsSqueezeOp, AtenOneHotOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index c21ba644a..676ca933a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -356,6 +356,12 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] =
 def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
+# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
+# making it impossible to add support for it using the current design of the shape library.
+def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
+    assert num_classes != -1, "getting num_classes from tensor contents is not supported"
+    return self + [num_classes]
+
 def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a80589157..3132e20aa 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -457,6 +457,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
     emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
+    emit("aten::one_hot : (Tensor, int) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
     emit("aten::clone : (Tensor, int?) -> (Tensor)")
     emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 6908b13e1..8a00d3091 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3564,3 +3564,22 @@ class PrimsViewOfZeroRankModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: PrimsViewOfZeroRankModule())
 def PrimsViewOfZeroRankModule_basic(module, tu: TestUtils):
     module.forward(tu.rand())
+
+
+# ==============================================================================
+
+
+class OneHotModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1], torch.long, True)])
+    def forward(self, x):
+        return torch.nn.functional.one_hot(x, num_classes=5)
+
+
+@register_test_case(module_factory=lambda: OneHotModule())
+def OneHotModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(10, high=5))
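
For reference, the decomposition introduced in DecomposeComplexOps.cpp (arange over the class indices, unsqueeze of the input, elementwise eq, then a cast to si64) corresponds to the eager-mode PyTorch sketch below. This is a minimal illustration only, assuming a constant, non-negative num_classes; the helper name one_hot_decomposed is illustrative and not part of the patch.

# Minimal sketch of the AtenOneHotOp decomposition in plain PyTorch
# (assumption: constant num_classes >= 1; not part of the patch itself).
import torch

def one_hot_decomposed(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    # aten.arange: class indices [0, 1, ..., num_classes - 1], dtype int64.
    classes = torch.arange(num_classes)
    # aten.unsqueeze + aten.eq.Tensor: add a trailing dim so broadcasting
    # compares every element of x against every class index, yielding a
    # boolean tensor of shape x.shape + (num_classes,).
    eq = x.unsqueeze(-1) == classes
    # aten.to.dtype: cast the boolean result to int64, matching the output
    # dtype of torch.nn.functional.one_hot.
    return eq.to(torch.int64)

x = torch.randint(0, 5, (10,))
assert torch.equal(one_hot_decomposed(x, 5),
                   torch.nn.functional.one_hot(x, num_classes=5))

Formulating one-hot as a broadcasted comparison keeps the decomposition in terms of ops (arange, unsqueeze, eq, to.dtype) that the existing lowerings already handle, which is consistent with OneHotModule_basic being added to STABLEHLO_PASS_SET above while remaining in LTC_XFAIL_SET.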