[onnx] Support `onnx.OneHot` lowering to `torch` (#3196)

Leverage the `aten.one_hot` implementation along with `aten.transpose.int`
to move the generated one-hot dimension into place and `aten.where.Scalar`
to select the on/off values.
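
For reference, the decomposition behaves like the following PyTorch sketch
(a hypothetical helper for illustration, not code from this patch):

```python
import torch

def onnx_onehot_reference(indices, depth, values, axis=-1):
    # Hypothetical reference mirroring the lowering below.
    # ONNX packs [off_value, on_value] into the rank-1 `values` tensor.
    off, on = values[0], values[1]
    # aten.one_hot always appends the new dimension last;
    # indices must be int64 and depth an integer.
    onehot = torch.nn.functional.one_hot(indices, int(depth))
    rank = onehot.dim()
    # Negative axes count from the back of the output rank.
    axis = axis + rank if axis < 0 else axis
    # Walk the trailing one-hot dimension into position `axis`.
    for i in range(rank - 1, axis, -1):
        onehot = onehot.transpose(i - 1, i)
    # Select on/off values with the boolean mask.
    return torch.where(onehot.bool(), on, off)
```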
Rob Suderman 2024-04-26 12:08:15 -07:00 committed by GitHub
parent ac85338491
commit 9a12a093a6
2 changed files with 103 additions and 5 deletions

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

@@ -1566,6 +1566,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, input);
return success();
});
patterns.onOp(
"OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<Value> inputs;
Torch::ValueTensorType resultType;
if (binder.tensorOperandsList(inputs) ||
binder.tensorResultType(resultType))
return failure();
if (inputs.size() != 3)
return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");
int64_t axis;
if (binder.s64IntegerAttr(axis, "axis", -1))
return rewriter.notifyMatchFailure(binder.op,
"`axis` attr not found");
auto loc = binder.getLoc();
Value indices = inputs[0];
Value depth = inputs[1];
Value values = inputs[2];
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
auto depthTy = cast<Torch::ValueTensorType>(depth.getType());
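// A negative axis counts from the back of the *output* rank, which is the
// indices rank plus the inserted one-hot dimension.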
axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis;
bool depthIsInt = isa<IntegerType>(depthTy.getDtype());
Type intTy = rewriter.getType<Torch::IntType>();
Type floatTy = rewriter.getType<Torch::FloatType>();
Type depthETy = depthIsInt ? intTy : floatTy;
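// Extract `depth` as a scalar; aten.one_hot requires an integer count, so
// a floating-point depth is converted via aten.Int.Scalar.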
depth = rewriter.create<Torch::AtenItemOp>(loc, depthETy, depth);
if (!depthIsInt)
depth = rewriter.create<Torch::AtenIntScalarOp>(
loc, rewriter.getType<Torch::IntType>(), depth);
auto selectTy = rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
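// ONNX packs [off_value, on_value] into the rank-1 `values` operand;
// select each element and extract it as a scalar.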
Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, zero);
off = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), off);
Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
values, zero, one);
on = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), on);
auto i32Ty = rewriter.getIntegerType(32, true);
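// aten.one_hot appends the one-hot dimension last, so the result shape is
// the indices shape plus a trailing dimension of (dynamic) size depth.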
llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
onehotShape.push_back(Torch::kUnknownSize);
auto onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
Value onehot = rewriter.create<Torch::AtenOneHotOp>(
binder.getLoc(), onehotTy, indices, depth);
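// Transpose the trailing one-hot dimension backwards until it lands at
// the requested axis.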
for (int i = indicesTy.getSizes().size(); i > axis; --i) {
std::swap(onehotShape[i - 1], onehotShape[i]);
Value iv0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
Value iv1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i - 1));
onehotTy =
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
onehot, iv1, iv0);
}
// Change one hot to an array of booleans to select value:
auto i1Ty = rewriter.getI1Type();
auto torchqTy = Torch::getScalarTypeForType(i1Ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));
onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
onehot = rewriter.create<Torch::AtenToDtypeOp>(
loc, onehotTy, onehot, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
onehotTy = rewriter.getType<Torch::ValueTensorType>(
onehotShape, resultType.getDtype());
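// Use the boolean mask to pick between the on/off scalars.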
onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, onehotTy,
onehot, on, off);
rewriter.replaceOp(binder.op, onehot);
return success();
});
patterns.onOp("HardSwish", 14,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

projects/pt1/e2e_testing/xfail_sets.py

@@ -2644,9 +2644,6 @@ ONNX_XFAIL_SET = {
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic",
# Failure - onnx_lowering: onnx.ReduceProd
"ReduceProdFloatModule_basic",
@@ -2655,7 +2652,7 @@ ONNX_XFAIL_SET = {
"ReduceProdUnsignedIntModule_basic",
"ReduceProdSignedIntModule_basic",
"ReduceProdDtypeIntModule_basic",
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
@@ -2679,7 +2676,7 @@ ONNX_XFAIL_SET = {
"ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf",
"ScatterValueFloatModule_basic",
# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",