mirror of https://github.com/llvm/torch-mlir
parent
3e83a86354
commit
72c3326097
|
@ -264,6 +264,7 @@ STABLEHLO_PASS_SET = {
|
|||
"Mv_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"OneHotModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
|
@ -935,4 +936,5 @@ LTC_XFAIL_SET = {
|
|||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
"OneHotModule_basic",
|
||||
}
|
||||
|
|
|
@ -6225,6 +6225,30 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::one_hot : (Tensor, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$num_classes
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenOneHotOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenOneHotOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6435,6 +6435,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %0 = torch.aten.ne.int %arg1, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %0 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %2 = torch.aten.add.t %arg0, %1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
|
|
|
@ -4165,6 +4165,65 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenOneHotOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto context = op.getContext();
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
int64_t numClasses;
|
||||
if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: num_classes must be constant");
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
|
||||
// arange tensor
|
||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||
auto arangeType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type);
|
||||
Value arangeTensor = rewriter.create<AtenArangeOp>(
|
||||
loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
|
||||
// unsqueeze input
|
||||
llvm::SmallVector<int64_t> unsqueezeShape(inputType.getSizes());
|
||||
unsqueezeShape.push_back(1);
|
||||
auto unsqueezeType =
|
||||
ValueTensorType::get(context, unsqueezeShape, si64Type);
|
||||
Value unsqueezeTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
loc, unsqueezeType, input,
|
||||
rewriter.create<ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(inputRank)));
|
||||
|
||||
// compare
|
||||
auto eqType = ValueTensorType::get(
|
||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
||||
IntegerType::get(context, 1));
|
||||
Value eqTensor = rewriter.create<AtenEqTensorOp>(
|
||||
loc, eqType, unsqueezeTensor, arangeTensor);
|
||||
|
||||
// convert to si64
|
||||
Value si64TypeValue =
|
||||
Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
|
||||
Value result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
|
||||
/*copy=*/falseValue, /*memory_format=*/none);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -4325,6 +4384,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
|
@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenBucketizeTensorOp>();
|
||||
target.addIllegalOp<PrimsSqueezeOp>();
|
||||
target.addIllegalOp<AtenMovedimIntOp>();
|
||||
target.addIllegalOp<AtenOneHotOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||
|
|
|
@ -659,7 +659,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
|
||||
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
|
||||
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp,
|
||||
PrimsSqueezeOp>(op)) {
|
||||
PrimsSqueezeOp, AtenOneHotOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
|
|
|
@ -356,6 +356,12 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] =
|
|||
def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
|
||||
# making it impossible to add support for it using the current design of the shape library.
|
||||
def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
|
||||
assert num_classes != -1, "getting num_classes from tensor contents is not supported"
|
||||
return self + [num_classes]
|
||||
|
||||
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.argmax(self, dim, keepdim)
|
||||
|
||||
|
|
|
@ -457,6 +457,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
|
||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::clone : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -3564,3 +3564,22 @@ class PrimsViewOfZeroRankModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: PrimsViewOfZeroRankModule())
|
||||
def PrimsViewOfZeroRankModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class OneHotModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1], torch.long, True)])
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.one_hot(x, num_classes=5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: OneHotModule())
|
||||
def OneHotModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(10, high=5))
|
||||
|
|
Loading…
Reference in New Issue