[MLIR][TORCH] Add e2e support for aten.randint

-- This commit adds e2e support for aten.randint by decomposing it into
   aten.randint.low with low set to 0.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek Varma 2023-04-04 09:31:21 +00:00 committed by Abhishek Varma
parent 0497f0b08d
commit 5337944ddb
8 changed files with 129 additions and 0 deletions
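
A quick sanity check of the PyTorch-level equivalence the decomposition relies
on: the default `aten::randint` overload samples uniformly from [0, high),
which is exactly `aten::randint.low` with `low` pinned to 0. A minimal eager
sketch; the seeded-equality check assumes both overloads share the same RNG
kernel path in ATen, which is not something the commit itself depends on:

import torch

torch.manual_seed(0)
a = torch.ops.aten.randint(1000, [4, 4])
torch.manual_seed(0)
b = torch.ops.aten.randint.low(0, 1000, [4, 4])  # the `low` overload
assert torch.equal(a, b)           # identical draws under the same seed
assert a.min() >= 0 and a.max() < 1000
assert a.dtype == torch.int64      # PyTorch's default dtype for randint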


@@ -3742,6 +3742,34 @@ def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
  }];
}
def Torch_AtenRandintOp : Torch_Op<"aten.randint", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)`";
  let arguments = (ins
    Torch_IntType:$high,
    AnyTorchListOfTorchIntType:$size,
    AnyTorchOptionalIntType:$dtype,
    AnyTorchOptionalIntType:$layout,
    AnyTorchOptionalDeviceType:$device,
    AnyTorchOptionalBoolType:$pin_memory
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenRandintOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 6, 1);
    }
    void AtenRandintOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 6, 1);
    }
  }];
}
def Torch_AtenBernoulliTensorOp : Torch_Op<"aten.bernoulli.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,

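The ODS arguments above mirror the registered JIT schema for the default
`aten::randint` overload. One way to inspect that schema from Python, assuming
a PyTorch build where OpOverload objects expose `_schema` (the exact int/SymInt
spelling varies by version):

import torch

print(torch.ops.aten.randint.default._schema)
# e.g. aten::randint(int high, int[] size, *, ScalarType? dtype=4,
#      Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor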

@@ -6881,6 +6881,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randint\"(%arg0: !torch.int, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"


@@ -3814,6 +3814,28 @@ public:
};
} // namespace
namespace {
class DecomposeAtenRandintOp : public OpRewritePattern<AtenRandintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRandintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultType = op.getType();
    Value low = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace
namespace {
// Decompose `aten.var_mean.correction` op into `aten.var.correction` and
// `aten.mean.dim` ops.
@@ -4287,6 +4309,7 @@ public:
      patterns);
  addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
  addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
  addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
  addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
  addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
  addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);


@@ -461,6 +461,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
  target.addIllegalOp<AtenMseLossOp>();
  target.addIllegalOp<AtenRandintLowOp>();
  target.addIllegalOp<AtenRandintOp>();
  target.addIllegalOp<AtenVarMeanCorrectionOp>();
  target.addIllegalOp<PrimsConvertElementTypeOp>();
  target.addIllegalOp<PrimsVarOp>();


@@ -1157,6 +1157,17 @@ void TypeAnalysis::visitOperation(Operation *op,
    return;
  }
  if (auto randInt = dyn_cast<AtenRandintOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    Type defaultDtype =
        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    knowledge.dtype =
        getDtypeOrDefault(op->getContext(), randInt.getDtype(), defaultDtype);
    incorporateKnowledge(randInt.getResult(), knowledge);
    return;
  }
  if (isa<AtenVarMeanCorrectionOp, AtenVarMeanOp>(op)) {
    auto input = operands[0]->getValue();
    auto knowledge =

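The net effect mirrors eager PyTorch: the result dtype defaults to signed
64-bit (si64) and otherwise follows an explicit `dtype` argument. A two-line
check against eager mode:

import torch

assert torch.ops.aten.randint(1000, [4]).dtype == torch.int64
assert torch.ops.aten.randint(1000, [4], dtype=torch.float64).dtype == torch.float64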

@@ -625,6 +625,9 @@ def aten〇randn_like〡shape(self: List[int], dtype: Optional[int] = None, layo
def aten〇randint〇low〡shape(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇randint〡shape(high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

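A standalone restatement of the new shape function, runnable outside the shape
library (the name `randint_shape` is illustrative only; the dtype default of 4
is the ScalarType code for int64):

from typing import List, Optional

def randint_shape(high: int, size: List[int],
                  dtype: Optional[int] = 4) -> List[int]:
    # The result shape of aten.randint is just its `size` operand,
    # independent of `high` and of the dtype.
    return size

assert randint_shape(1000, [128, 256, 512]) == [128, 256, 512]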

@@ -334,6 +334,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)")


@@ -329,6 +329,65 @@ def RandIntLowDtypeModule_basic(module, tu: TestUtils):
# ==============================================================================
class RandIntModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        a = torch.ops.aten.randint(high=1000, size=[1024, 1024])
        mean = torch.mean(a.to(torch.float32))
        return mean


@register_test_case(module_factory=lambda: RandIntModule())
def RandIntModule_basic(module, tu: TestUtils):
    module.forward()
# ==============================================================================
class RandIntDtypeModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64)
        mean = torch.mean(a.to(torch.float32))
        return mean


@register_test_case(module_factory=lambda: RandIntDtypeModule())
def RandIntDtypeModule_basic(module, tu: TestUtils):
    module.forward()
# ==============================================================================
class RandIntPinMemoryModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False)
        mean = torch.mean(a.to(torch.float32))
        return mean


@register_test_case(module_factory=lambda: RandIntPinMemoryModule())
def RandIntPinMemoryModule_basic(module, tu: TestUtils):
    module.forward()
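
Why all three tests reduce a large sample to its mean: the draws are uniform
over [0, 1000), so across 10^5-10^6 elements the sample mean concentrates
tightly around 999 / 2 = 499.5. That lets the e2e harness compare the compiled
module against eager PyTorch within a tolerance even though the individual
random values differ. A quick eager-mode illustration:

import torch

sample = torch.ops.aten.randint(1000, [1024, 1024]).to(torch.float32)
assert abs(sample.mean().item() - 499.5) < 2.0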
# ==============================================================================
class RandnModule(torch.nn.Module):