diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index d2f66e62b..96459a3a0 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2707,7 +2707,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value onehot = rewriter.create<Torch::AtenOneHotOp>(
            binder.getLoc(), onehotTy, indices, depth);
 
-        for (int i = valuesTy.getSizes().size(); i > axis; ++i) {
+        for (int i = indicesTy.getSizes().size(); i > axis; --i) {
           std::swap(onehotShape[i - 1], onehotShape[i]);
           Value iv0 = rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(i));
@@ -2716,7 +2716,7 @@
 
           onehotTy =
               rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
-          onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, resultType,
+          onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
                                                               onehot, iv1, iv0);
         }
 
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index abb84dff4..12130e0d9 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8301,6 +8301,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
           op, "input tensor should have known sizes.");
     int64_t inputRank = inputType.getSizes().size();
     int64_t numClasses = Torch::kUnknownSize;
+    auto resultType = cast<BaseTensorType>(op.getType());
     matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses));
 
     Value none = rewriter.create<ConstantNoneOp>(loc);
@@ -8313,14 +8314,15 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
                                  /*device=*/none, /*pin_memory=*/none);
 
     // unsqueeze input
-    llvm::SmallVector<int64_t> unsqueezeShape(inputType.getSizes());
-    unsqueezeShape.push_back(1);
-    auto unsqueezeType =
-        ValueTensorType::get(context, unsqueezeShape, si64Type);
-    Value unsqueezeTensor = rewriter.create<AtenUnsqueezeOp>(
-        loc, unsqueezeType, input,
-        rewriter.create<ConstantIntOp>(loc,
-                                       rewriter.getI64IntegerAttr(inputRank)));
+    Value rankV = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputRank));
+    auto unsqueeze = Torch::unsqueezeTensor(rewriter, op, input, rankV);
+    if (failed(unsqueeze))
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
+
+    Value unsqueezeTensor =
+        convertTensorToDtype(rewriter, loc, *unsqueeze, si64Type);
 
     // compare
     auto eqType = ValueTensorType::get(
@@ -8330,7 +8332,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
         loc, eqType, unsqueezeTensor, arangeTensor);
 
     // convert to si64
-    Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
+    Value result =
+        convertTensorToDtype(rewriter, loc, eqTensor, resultType.getDtype());
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 3bc9b2b4b..c879cefc5 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1480,7 +1480,7 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
 // -----
 
 // CHECK-LABEL: func.func @test_onehot_negative_indices
-func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,10],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
   // CHECK: %[[NONE:.*]] = torch.constant.none
   // CHECK: %[[ITEM:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[INT:.*]] = torch.aten.Int.Scalar %[[ITEM]] : !torch.float -> !torch.int
@@ -1494,15 +1494,18 @@ func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1:
   // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg2, %[[C0]], %[[C1]]: !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
   // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[ONEHOT:.*]] = torch.aten.one_hot %[[WHERE]], %[[INT]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si32>
+  // CHECK: %[[D0:.+]] = torch.constant.int 1
+  // CHECK: %[[D1:.+]] = torch.constant.int 0
+  // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %[[ONEHOT]], %[[D1]], %[[D0]]
   // CHECK: %[[C11:.*]] = torch.constant.int 11
   // CHECK: %[[NONE_0:.*]] = torch.constant.none
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %[[ONEHOT]], %[[C11]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,?],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],i1>
-  // CHECK: %[[RESULT:.*]] = torch.aten.where.Scalar %[[DTYPE]], %[[ITEM_1]], %[[ITEM_0]] : !torch.vtensor<[3,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[3,10],f32>
-  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,10],f32>
+  // CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %[[TRANS]], %[[C11]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[?,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3],i1>
+  // CHECK: %[[RESULT:.*]] = torch.aten.where.Scalar %[[DTYPE]], %[[ITEM_1]], %[[ITEM_0]] : !torch.vtensor<[?,3],i1>, !torch.float, !torch.float -> !torch.vtensor<[10,3],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[10,3],f32>
   %none = torch.constant.none
-  %0 = torch.operator "onnx.OneHot"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,10],f32>
-  return %0 : !torch.vtensor<[3,10],f32>
+  %0 = torch.operator "onnx.OneHot"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32>
+  return %0 : !torch.vtensor<[10,3],f32>
 }
 
 // -----