mirror of https://github.com/llvm/torch-mlir
[onnx] Support `onnx.OneHot` lowering to `torch` (#3196)
[onnx] Support `onnx.OneHot` lowering to `torch` Leverage the `aten.onehot` implementation along with `aten.transpose` and `aten.where.scalar`.pull/3242/head
parent
ac85338491
commit
9a12a093a6
|
@ -1566,6 +1566,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, resultType, input);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
llvm::SmallVector<Value> inputs;
|
||||
Torch::ValueTensorType resultType;
|
||||
if (binder.tensorOperandsList(inputs) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != 3)
|
||||
return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");
|
||||
|
||||
int64_t axis;
|
||||
if (binder.s64IntegerAttr(axis, "axis", -1))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"`axis` attr not found");
|
||||
|
||||
auto loc = binder.getLoc();
|
||||
Value indices = inputs[0];
|
||||
Value depth = inputs[1];
|
||||
Value values = inputs[2];
|
||||
|
||||
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
|
||||
auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
|
||||
auto depthTy = cast<Torch::ValueTensorType>(depth.getType());
|
||||
|
||||
axis = axis < 0 ? axis + indicesTy.getSizes().size() + 1 : axis;
|
||||
|
||||
bool depthIsInt = isa<IntegerType>(depthTy.getDtype());
|
||||
Type intTy = rewriter.getType<Torch::IntType>();
|
||||
Type floatTy = rewriter.getType<Torch::FloatType>();
|
||||
Type depthETy = depthIsInt ? intTy : floatTy;
|
||||
depth = rewriter.create<Torch::AtenItemOp>(loc, depthETy, depth);
|
||||
|
||||
if (!depthIsInt)
|
||||
depth = rewriter.create<Torch::AtenIntScalarOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), depth);
|
||||
|
||||
auto selectTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
|
||||
Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
|
||||
values, zero, zero);
|
||||
off = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), off);
|
||||
|
||||
Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
|
||||
values, zero, one);
|
||||
on = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), on);
|
||||
|
||||
auto i32Ty = rewriter.getIntegerType(32, true);
|
||||
llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
|
||||
onehotShape.push_back(Torch::kUnknownSize);
|
||||
auto onehotTy =
|
||||
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
|
||||
|
||||
Value onehot = rewriter.create<Torch::AtenOneHotOp>(
|
||||
binder.getLoc(), onehotTy, indices, depth);
|
||||
|
||||
for (int i = valuesTy.getSizes().size(); i > axis; ++i) {
|
||||
std::swap(onehotShape[i - 1], onehotShape[i]);
|
||||
Value iv0 = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
Value iv1 = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i - 1));
|
||||
|
||||
onehotTy =
|
||||
rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
|
||||
onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
|
||||
onehot, iv1, iv0);
|
||||
}
|
||||
|
||||
// Change one hot to an array of booleans to select value:
|
||||
auto i1Ty = rewriter.getI1Type();
|
||||
auto torchqTy = Torch::getScalarTypeForType(i1Ty);
|
||||
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
static_cast<int64_t>(torchqTy)));
|
||||
|
||||
onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
onehot = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
loc, onehotTy, onehot, tyConst,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
|
||||
onehotTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
onehotShape, resultType.getDtype());
|
||||
onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, onehotTy,
|
||||
onehot, on, off);
|
||||
|
||||
rewriter.replaceOp(binder.op, onehot);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("HardSwish", 14,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -2644,9 +2644,6 @@ ONNX_XFAIL_SET = {
|
|||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||
"MaxPool2dWithIndicesStaticModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.OneHot
|
||||
"OneHotModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ReduceProd
|
||||
"ReduceProdFloatModule_basic",
|
||||
|
@ -2655,7 +2652,7 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceProdUnsignedIntModule_basic",
|
||||
"ReduceProdSignedIntModule_basic",
|
||||
"ReduceProdDtypeIntModule_basic",
|
||||
|
||||
|
||||
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
|
||||
"RandnDtypeDeviceModule_basic",
|
||||
"RandnGeneratorF64Module_basic",
|
||||
|
@ -2679,7 +2676,7 @@ ONNX_XFAIL_SET = {
|
|||
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||
"ScatterReduceIntMinModuleIncludeSelf",
|
||||
"ScatterValueFloatModule_basic",
|
||||
|
||||
|
||||
# Failure - onnx_lowering: onnx.ScatterND
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
|
|
Loading…
Reference in New Issue