diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 5046b8859..cf14fc026 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1778,15 +1778,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
 
           bool valuesAreInt = isa<IntegerType>(valuesTy.getDtype());
-          Type valueEty = valuesAreInt ? intTy : floatTy;
+          Type valuesETy = valuesAreInt ? intTy : floatTy;
 
           Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                               values, zero, zero);
-          off = rewriter.create<Torch::AtenItemOp>(loc, valueEty, off);
+          off = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, off);
 
           Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                              values, zero, one);
-          on = rewriter.create<Torch::AtenItemOp>(loc, valueEty, on);
+          on = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, on);
 
           auto i32Ty = rewriter.getIntegerType(32, true);
           llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
@@ -1806,7 +1806,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             onehotTy =
                 rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
 
-            onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
+            onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, resultType,
                                                                 onehot, iv1, iv0);
           }
diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index 2703d4872..4543c5e5e 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -439,9 +439,10 @@ public:
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
 
-    target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp>();
+    target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp, AtenIntScalarOp>();
     patterns.add<ConvertAtenCastOp<AtenIntBoolOp>>(typeConverter, context);
     patterns.add<ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
+    patterns.add<ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);
 
     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index ce88854f1..6ca4fb205 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7174,10 +7174,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
     int64_t inputRank = inputType.getSizes().size();
-    int64_t numClasses;
-    if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: num_classes must be constant");
+    int64_t numClasses = Torch::kUnknownSize;
+    matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses));
     Value none = rewriter.create<ConstantNoneOp>(loc);
 
     // arange tensor
diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index 9f23229d5..ca2926ae1 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -326,3 +326,14 @@ func.func @torch.aten.Int.bool(%arg0: !torch.bool) -> !torch.int {
   %0 = torch.aten.Int.bool %arg0 : !torch.bool -> !torch.int
   return %0 : !torch.int
 }
+
+// CHECK-LABEL: func.func @torch.aten.Int.Scalar(
+// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
+// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
+// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[ARG_F64]] : f64 to i64
+// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[FPTOSI]]
+// CHECK: return %[[OUT]] : !torch.int
+func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int {
+  %0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
+  return %0 : !torch.int
+}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 530160f99..a3711c15e 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -78,3 +78,22 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
   %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
   return %0 : !torch.tensor<[?], f16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.one_hot$fold(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[INT4:.*]] = torch.constant.int 4
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT0]], %arg1, %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG_0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64>
+// CHECK: %[[EQ:.*]] = torch.aten.eq.Tensor %[[UNSQUEEZE]], %[[ARANGE]] : !torch.vtensor<[3,1],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[3,?],i1>
+// CHECK: %[[RESULT:.*]] = torch.aten.to.dtype %[[EQ]], %[[INT4]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],si64>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[3,?],si64>
+func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
+  %0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
+  return %0 : !torch.vtensor<[3,?],si64>
+}
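
// -----

// NOTE: Illustrative sketch only, not part of the patch above. Taken together,
// the three changes give onnx.OneHot a lowering path when the one-hot size is
// only known at runtime: the onnx lowering yields the result type directly,
// DecomposeAtenOneHotOp accepts a non-constant num_classes, and the scalar
// depth conversion it relies on (torch.aten.Int.Scalar) now lowers to arith.
// Function name, shapes, and onnx_meta attributes below are assumptions:
func.func @onehot_dynamic_depth(%indices: !torch.vtensor<[3],si64>, %depth: !torch.vtensor<[],si64>, %values: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 13 : si64} {
  // Depth is a runtime value, so the one-hot dimension of the result is '?'.
  %0 = torch.operator "onnx.OneHot"(%indices, %depth, %values) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,?],si64>
  return %0 : !torch.vtensor<[3,?],si64>
}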