diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index e8ac725fc..c0df70b12 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1566,6 +1566,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, input);
         return success();
       });
+  patterns.onOp(
+      "OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        llvm::SmallVector<Value> inputs;
+        Torch::ValueTensorType resultType;
+        if (binder.tensorOperandsList(inputs) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        if (inputs.size() != 3)
+          return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");
+
+        int64_t axis;
+        if (binder.s64IntegerAttr(axis, "axis", -1))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "`axis` attr not found");
+
+        auto loc = binder.getLoc();
+        Value indices = inputs[0];
+        Value depth = inputs[1];
+        Value values = inputs[2];
+
+        auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
+        auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
+        auto depthTy = cast<Torch::ValueTensorType>(depth.getType());
+
+        axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis;
+
+        bool depthIsInt = isa<IntegerType>(depthTy.getDtype());
+        Type intTy = rewriter.getType<Torch::IntType>();
+        Type floatTy = rewriter.getType<Torch::FloatType>();
+        Type depthETy = depthIsInt ? intTy : floatTy;
+        depth = rewriter.create<Torch::AtenItemOp>(loc, depthETy, depth);
+
+        if (!depthIsInt)
+          depth = rewriter.create<Torch::AtenIntScalarOp>(
+              loc, rewriter.getType<Torch::IntType>(), depth);
+
+        auto selectTy = rewriter.getType<Torch::ValueTensorType>(
+            llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
+
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+
+        Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
+                                                            values, zero, zero);
+        off = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), off);
+
+        Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
+                                                           values, zero, one);
+        on = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), on);
+
+        auto i32Ty = rewriter.getIntegerType(32, true);
+        llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
+        onehotShape.push_back(Torch::kUnknownSize);
+        auto onehotTy =
+            rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
+
+        Value onehot = rewriter.create<Torch::AtenOneHotOp>(
+            binder.getLoc(), onehotTy, indices, depth);
+
+        for (int i = indicesTy.getSizes().size(); i > axis; --i) {
+          std::swap(onehotShape[i - 1], onehotShape[i]);
+          Value iv0 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i));
+          Value iv1 = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i - 1));
+
+          onehotTy =
+              rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
+          onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
+                                                              onehot, iv1, iv0);
+        }
+
+        // Change one hot to an array of booleans to select value:
+        auto i1Ty = rewriter.getI1Type();
+        auto torchqTy = Torch::getScalarTypeForType(i1Ty);
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(torchqTy)));
+
+        onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        onehot = rewriter.create<Torch::AtenToDtypeOp>(
+            loc, onehotTy, onehot, tyConst,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
+        onehotTy = rewriter.getType<Torch::ValueTensorType>(
+            onehotShape, resultType.getDtype());
+        onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, onehotTy,
+                                                           onehot, on, off);
+
+        rewriter.replaceOp(binder.op, onehot);
+        return success();
+      });
   patterns.onOp("HardSwish", 14,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 55a005e68..2b47c40c4 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2644,9 +2644,6 @@ ONNX_XFAIL_SET = {
     "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
     "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
     "MaxPool2dWithIndicesStaticModule_basic",
-
-    # Failure - onnx_lowering: onnx.OneHot
-    "OneHotModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceProd
     "ReduceProdFloatModule_basic",
@@ -2655,7 +2652,7 @@ ONNX_XFAIL_SET = {
     "ReduceProdUnsignedIntModule_basic",
     "ReduceProdSignedIntModule_basic",
     "ReduceProdDtypeIntModule_basic",
-
+
     # ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
     "RandnDtypeDeviceModule_basic",
     "RandnGeneratorF64Module_basic",
@@ -2679,7 +2676,7 @@ ONNX_XFAIL_SET = {
     "ScatterReduceIntMaxModuleIncludeSelf",
     "ScatterReduceIntMinModuleIncludeSelf",
     "ScatterValueFloatModule_basic",
-
+
     # Failure - onnx_lowering: onnx.ScatterND
     "IndexPut1DFloatAccumulateModule_basic",
     "IndexPut1DFloatNonAccumulateModule_basic",
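
Reference note (not part of the patch): the lowering above decomposes `onnx.OneHot` into `aten.one_hot`, a chain of `aten.transpose.int` ops to move the new dimension to `axis`, an `aten.to.dtype` cast to bool, and an `aten.where.Scalar` that selects the on/off values. A minimal PyTorch sketch of the same computation, with illustrative names and the usual ONNX assumptions (integer `indices`, scalar `depth`, `values = [off_value, on_value]`):

```python
import torch

def onehot_reference(indices, depth, values, axis=-1):
    # Normalize a negative axis against rank(indices) + 1, as the lowering does.
    axis = axis if axis >= 0 else axis + indices.dim() + 1
    # aten.one_hot places the new dimension last.
    onehot = torch.nn.functional.one_hot(indices.long(), int(depth))
    # Bubble the new dimension down to `axis` with adjacent transposes,
    # mirroring the chain of aten.transpose.int ops.
    for i in range(indices.dim(), axis, -1):
        onehot = onehot.transpose(i - 1, i)
    # Cast to bool and pick on/off values (aten.to.dtype + aten.where.Scalar).
    return torch.where(onehot.to(torch.bool), values[1], values[0])
```

This is the behavior exercised by the `OneHotModule_basic` e2e case removed from `ONNX_XFAIL_SET` above.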