mirror of https://github.com/llvm/torch-mlir
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
This PR fixes the bugs for `Torch::AtenOneHotOp` by: 1) Using `Torch::kUnknownSize` as the default value for `numClasses` in the pattern matching stage in `DecomposeAtenOneHotOp` 2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith` 3) Handling both `int` and `float` types for `off` and `on` values in `TorchOnnxToTorch` conversion It also includes: 1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`, and 2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot` **Dependencies** This PR is dependent on #3334.pull/3382/head
parent
f4bfe3f948
commit
2e194e13d6
|
@ -1778,15 +1778,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
|
||||
|
||||
bool valuesAreInt = isa<IntegerType>(valuesTy.getDtype());
|
||||
Type valueEty = valuesAreInt ? intTy : floatTy;
|
||||
Type valuesETy = valuesAreInt ? intTy : floatTy;
|
||||
|
||||
Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
|
||||
values, zero, zero);
|
||||
off = rewriter.create<Torch::AtenItemOp>(loc, valueEty, off);
|
||||
off = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, off);
|
||||
|
||||
Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
|
||||
values, zero, one);
|
||||
on = rewriter.create<Torch::AtenItemOp>(loc, valueEty, on);
|
||||
on = rewriter.create<Torch::AtenItemOp>(loc, valuesETy, on);
|
||||
|
||||
auto i32Ty = rewriter.getIntegerType(32, true);
|
||||
llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
|
||||
|
@ -1806,7 +1806,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
onehotTy =
|
||||
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
|
||||
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
|
||||
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, resultType,
|
||||
onehot, iv1, iv0);
|
||||
}
|
||||
|
||||
|
|
|
@ -439,9 +439,10 @@ public:
|
|||
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp>();
|
||||
target.addIllegalOp<AtenIntBoolOp, AtenFloatScalarOp, AtenIntScalarOp>();
|
||||
patterns.add<ConvertAtenCastOp<AtenIntBoolOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAddOp>();
|
||||
patterns.add<ConvertAtenAddOp>(typeConverter, context);
|
||||
|
|
|
@ -7174,10 +7174,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
int64_t numClasses;
|
||||
if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: num_classes must be constant");
|
||||
int64_t numClasses = Torch::kUnknownSize;
|
||||
matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses));
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
// arange tensor
|
||||
|
|
|
@ -326,3 +326,14 @@ func.func @torch.aten.Int.bool(%arg0: !torch.bool) -> !torch.int {
|
|||
%0 = torch.aten.Int.bool %arg0 : !torch.bool -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.Int.Scalar(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
|
||||
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
|
||||
// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[ARG_F64]] : f64 to i64
|
||||
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[FPTOSI]]
|
||||
// CHECK: return %[[OUT]] : !torch.int
|
||||
func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int {
|
||||
%0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
|
|
@ -78,3 +78,22 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
|
|||
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
||||
return %0 : !torch.tensor<[?], f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.one_hot$fold(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT0]], %arg1, %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG_0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64>
|
||||
// CHECK: %[[EQ:.*]] = torch.aten.eq.Tensor %[[UNSQUEEZE]], %[[ARANGE]] : !torch.vtensor<[3,1],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[3,?],i1>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.to.dtype %[[EQ]], %[[INT4]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],si64>
|
||||
// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,?],si64>
|
||||
func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> {
|
||||
%0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
|
||||
return %0 : !torch.vtensor<[3,?],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue