[Torch Dialect] Support aten.device.with_index (#2254)

pull/2143/head
Yuanqiang Liu 2023-06-23 01:07:14 +08:00 committed by GitHub
parent 4fd4477e15
commit 96b14e952e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 0 deletions

View File

@ -6996,6 +6996,31 @@ def Torch_AtenDetachOp : Torch_Op<"aten.detach", [
let hasFolder = 1;
}
def Torch_AtenDeviceWithIndexOp : Torch_Op<"aten.device.with_index", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::device.with_index : (str, int) -> (Device)`";
let arguments = (ins
Torch_StringType:$type,
Torch_IntType:$index
);
let results = (outs
Torch_DeviceType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDeviceWithIndexOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenDeviceWithIndexOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenCudaOp : Torch_Op<"aten.cuda", [
AllowsTypeRefinement,
ReadOnly

View File

@ -2455,6 +2455,29 @@ void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenDeviceWithIndexOp
//===----------------------------------------------------------------------===//
void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) {
std::string type;
int64_t index;
if (!matchPattern(op.getType(), m_TorchConstantStr(type))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: type must be a constant string");
}
if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: index must be a constant integer");
}
rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(
op, type + ":" + std::to_string(index));
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//

View File

@ -484,6 +484,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")

View File

@ -2001,3 +2001,13 @@ func.func @torch.aten.cuda$canonicalize(%arg0: !torch.tensor) -> !torch.tensor {
%0 = torch.aten.cuda %arg0 : !torch.tensor -> !torch.tensor
return %0 : !torch.tensor
}
// CHECK-LABEL: func.func @torch.aten.device.with_index$canonicalize
// CHECK-NEXT: %[[VAL:.*]] = torch.constant.device "cuda:0"
// CHECK-NEXT: return %[[VAL]] : !torch.Device
func.func @torch.aten.device.with_index$canonicalize() -> !torch.Device {
%str = torch.constant.str "cuda"
%int0 = torch.constant.int 0
%0 = torch.aten.device.with_index %str, %int0 : !torch.str, !torch.int -> !torch.Device
return %0 : !torch.Device
}