[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)

This PR fixes bugs in `Torch::AtenOneHotOp` handling by:

1) Using `Torch::kUnknownSize` as the default value for `numClasses` in
   the pattern-matching stage of `DecomposeAtenOneHotOp`, so a
   non-constant `num_classes` no longer aborts the decomposition (see the
   sketch after this list)
2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith`
3) Handling both `int` and `float` types for the `off` and `on` values in
   the `TorchOnnxToTorch` conversion
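
For (1) concretely, a `one_hot` whose `num_classes` is only known at runtime used to be rejected with "unimplemented: num_classes must be constant"; it now decomposes with a dynamic class dimension. A minimal sketch (the function name is illustrative; the new `decompose-complex-ops.mlir` test below exercises the same case):

```mlir
// num_classes arrives as a runtime !torch.int, so the decomposed result
// keeps a dynamic (?) class dimension via Torch::kUnknownSize.
func.func @one_hot_dynamic(%indices: !torch.vtensor<[3],si64>, %num_classes: !torch.int) -> !torch.vtensor<[3,?],si64> {
  %0 = torch.aten.one_hot %indices, %num_classes : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
  return %0 : !torch.vtensor<[3,?],si64>
}
```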

It also includes:

1) A new test for `torch.aten.Int.Scalar` in `TorchToArith/basic.mlir`, and
2) A new test for `torch.aten.one_hot` in `decompose-complex-ops.mlir`

**Dependencies**

This PR depends on #3334.
**Changed files** (5 files, +38 -9)

**lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp**

```diff
@@ -1778,15 +1778,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
           bool valuesAreInt = isa<IntegerType>(valuesTy.getDtype());
-          Type valueEty = valuesAreInt ? intTy : floatTy;
+          Type valuesETy = valuesAreInt ? intTy : floatTy;

           Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                               values, zero, zero);
-          off = rewriter.create<Torch::AtenItemOp>(loc, valueEty, off);
+          off = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, off);

           Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                              values, zero, one);
-          on = rewriter.create<Torch::AtenItemOp>(loc, valueEty, on);
+          on = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, on);

           auto i32Ty = rewriter.getIntegerType(32, true);
           llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
@@ -1806,7 +1806,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             onehotTy =
                 rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
-            onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
+            onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, resultType,
                                                                 onehot, iv1, iv0);
           }
```
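
For illustration, a hypothetical IR fragment (not part of this commit) showing what the dtype-matched extraction yields when the ONNX `values` operand is a float tensor; with integer `values`, `off`/`on` would be `!torch.int` instead:

```mlir
// `off` is now extracted as !torch.float when `values` has a float dtype,
// instead of always being forced through !torch.int. Names and the [2]
// values shape are assumptions for the sketch.
func.func @onehot_float_values(%values: !torch.vtensor<[2],f32>) -> !torch.float {
  %int0 = torch.constant.int 0
  %off0 = torch.aten.select.int %values, %int0, %int0 : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  %off = torch.aten.item %off0 : !torch.vtensor<[1],f32> -> !torch.float
  return %off : !torch.float
}
```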

**lib/Conversion/TorchToArith/TorchToArith.cpp**

```diff
@@ -439,9 +439,10 @@ public:
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);

-    target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp>();
+    target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp, AtenIntScalarOp>();
     patterns.add<ConvertAtenCastOp<AtenIntBoolOp>>(typeConverter, context);
     patterns.add<ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
+    patterns.add<ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);

     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
```
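
With `AtenIntScalarOp` registered, `torch.aten.Int.Scalar` on a float input lowers through the same shared `ConvertAtenCastOp` path as the existing casts. A minimal standalone reproducer, assuming a built `torch-mlir-opt` (the RUN line mirrors the convention of the existing test files):

```mlir
// RUN: torch-mlir-opt %s --convert-torch-to-arith | FileCheck %s
// The float scalar is cast to an integer via arith.fptosi.
// CHECK: arith.fptosi {{.*}} : f64 to i64
func.func @int_scalar(%arg0: !torch.float) -> !torch.int {
  %0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
  return %0 : !torch.int
}
```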

**lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp**

```diff
@@ -7174,10 +7174,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
       return rewriter.notifyMatchFailure(
           op, "input tensor should have known sizes.");
     int64_t inputRank = inputType.getSizes().size();
-    int64_t numClasses;
-    if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: num_classes must be constant");
+    int64_t numClasses = Torch::kUnknownSize;
+    matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses));
     Value none = rewriter.create<ConstantNoneOp>(loc);

     // arange tensor
```
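
The practical effect: when `num_classes` is not a compile-time constant, `numClasses` stays `Torch::kUnknownSize` and the `arange` emitted by the decomposition carries a dynamic extent. A sketch of that decomposed form (function name illustrative; it matches the CHECK lines in the new test below):

```mlir
// Decomposed arange for a non-constant num_classes: the result size is `?`.
func.func @decomposed_arange(%num_classes: !torch.int) -> !torch.vtensor<[?],si64> {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %none = torch.constant.none
  %0 = torch.aten.arange.start_step %int0, %num_classes, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
  return %0 : !torch.vtensor<[?],si64>
}
```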

**test/Conversion/TorchToArith/basic.mlir**

```diff
@@ -326,3 +326,14 @@ func.func @torch.aten.Int.bool(%arg0: !torch.bool) -> !torch.int {
   %0 = torch.aten.Int.bool %arg0 : !torch.bool -> !torch.int
   return %0 : !torch.int
 }
+
+// CHECK-LABEL:  func.func @torch.aten.Int.Scalar(
+// CHECK-SAME:     %[[ARG:.*]]: !torch.float) -> !torch.int {
+// CHECK:          %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
+// CHECK:          %[[FPTOSI:.*]] = arith.fptosi %[[ARG_F64]] : f64 to i64
+// CHECK:          %[[OUT:.*]] = torch_c.from_i64 %[[FPTOSI]]
+// CHECK:          return %[[OUT]] : !torch.int
+func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int {
+  %0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
+  return %0 : !torch.int
+}
```

**test/Dialect/Torch/decompose-complex-ops.mlir**

```diff
@@ -78,3 +78,22 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
   %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
   return %0 : !torch.tensor<[?], f16>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.one_hot$fold(
+// CHECK-SAME:      %[[ARG_0:.*]]: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[INT4:.*]] = torch.constant.int 4
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[INT0:.*]] = torch.constant.int 0
+// CHECK:           %[[INT1:.*]] = torch.constant.int 1
+// CHECK:           %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT0]], %arg1, %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK:           %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG_0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64>
+// CHECK:           %[[EQ:.*]] = torch.aten.eq.Tensor %[[UNSQUEEZE]], %[[ARANGE]] : !torch.vtensor<[3,1],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[3,?],i1>
+// CHECK:           %[[RESULT:.*]] = torch.aten.to.dtype %[[EQ]], %[[INT4]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],si64>
+// CHECK:           return %[[RESULT]] : !torch.vtensor<[3,?],si64>
+func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
+  %0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
+  return %0 : !torch.vtensor<[3,?],si64>
+}
```