mirror of https://github.com/llvm/torch-mlir
[onnx] Fix transposition code for `onnx.OneHot` (#3606)
The post onehot transposition code was unexercised. Fixed the test and transformation to check use.pull/3617/head
parent
c8efc201f4
commit
59a4c6fda4
|
@ -2707,7 +2707,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value onehot = rewriter.create<Torch::AtenOneHotOp>(
|
||||
binder.getLoc(), onehotTy, indices, depth);
|
||||
|
||||
for (int i = valuesTy.getSizes().size(); i > axis; ++i) {
|
||||
for (int i = indicesTy.getSizes().size(); i > axis; --i) {
|
||||
std::swap(onehotShape[i - 1], onehotShape[i]);
|
||||
Value iv0 = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
|
@ -2716,7 +2716,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
onehotTy =
|
||||
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
|
||||
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, resultType,
|
||||
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
|
||||
onehot, iv1, iv0);
|
||||
}
|
||||
|
||||
|
|
|
@ -8301,6 +8301,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
op, "input tensor should have known sizes.");
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
int64_t numClasses = Torch::kUnknownSize;
|
||||
auto resultType = cast<ValueTensorType>(op.getType());
|
||||
matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses));
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
|
@ -8313,14 +8314,15 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
/*device=*/none, /*pin_memory=*/none);
|
||||
|
||||
// unsqueeze input
|
||||
llvm::SmallVector<int64_t> unsqueezeShape(inputType.getSizes());
|
||||
unsqueezeShape.push_back(1);
|
||||
auto unsqueezeType =
|
||||
ValueTensorType::get(context, unsqueezeShape, si64Type);
|
||||
Value unsqueezeTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
loc, unsqueezeType, input,
|
||||
rewriter.create<ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(inputRank)));
|
||||
Value rankV = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputRank));
|
||||
auto unsqueeze = Torch::unsqueezeTensor(rewriter, op, input, rankV);
|
||||
if (failed(unsqueeze))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor");
|
||||
|
||||
Value unsqueezeTensor =
|
||||
convertTensorToDtype(rewriter, loc, *unsqueeze, si64Type);
|
||||
|
||||
// compare
|
||||
auto eqType = ValueTensorType::get(
|
||||
|
@ -8330,7 +8332,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
loc, eqType, unsqueezeTensor, arangeTensor);
|
||||
|
||||
// convert to si64
|
||||
Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
|
||||
Value result =
|
||||
convertTensorToDtype(rewriter, loc, eqTensor, resultType.getDtype());
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -1480,7 +1480,7 @@ func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_onehot_negative_indices
|
||||
func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,10],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[INT:.*]] = torch.aten.Int.Scalar %[[ITEM]] : !torch.float -> !torch.int
|
||||
|
@ -1494,15 +1494,18 @@ func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1:
|
|||
// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg2, %[[C0]], %[[C1]]: !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
|
||||
// CHECK: %[[ONEHOT:.*]] = torch.aten.one_hot %[[WHERE]], %[[INT]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si32>
|
||||
// CHECK: %[[D0:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[D1:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %[[ONEHOT]], %[[D1]], %[[D0]]
|
||||
// CHECK: %[[C11:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %[[ONEHOT]], %[[C11]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,?],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],i1>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.where.Scalar %[[DTYPE]], %[[ITEM_1]], %[[ITEM_0]] : !torch.vtensor<[3,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[3,10],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,10],f32>
|
||||
// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %[[TRANS]], %[[C11]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[?,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3],i1>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.where.Scalar %[[DTYPE]], %[[ITEM_1]], %[[ITEM_0]] : !torch.vtensor<[?,3],i1>, !torch.float, !torch.float -> !torch.vtensor<[10,3],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[10,3],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.OneHot"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,10],f32>
|
||||
return %0 : !torch.vtensor<[3,10],f32>
|
||||
%0 = torch.operator "onnx.OneHot"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32>
|
||||
return %0 : !torch.vtensor<[10,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue