mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Support aten.device.with_index (#2254)
parent
4fd4477e15
commit
96b14e952e
|
@ -6996,6 +6996,31 @@ def Torch_AtenDetachOp : Torch_Op<"aten.detach", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDeviceWithIndexOp : Torch_Op<"aten.device.with_index", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::device.with_index : (str, int) -> (Device)`";
|
||||
let arguments = (ins
|
||||
Torch_StringType:$type,
|
||||
Torch_IntType:$index
|
||||
);
|
||||
let results = (outs
|
||||
Torch_DeviceType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDeviceWithIndexOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenDeviceWithIndexOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenCudaOp : Torch_Op<"aten.cuda", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -2455,6 +2455,29 @@ void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDeviceWithIndexOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) {
|
||||
std::string type;
|
||||
int64_t index;
|
||||
if (!matchPattern(op.getType(), m_TorchConstantStr(type))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: type must be a constant string");
|
||||
}
|
||||
if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: index must be a constant integer");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(
|
||||
op, type + ":" + std::to_string(index));
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -484,6 +484,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
|
||||
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
|
||||
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
|
||||
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
|
||||
|
|
|
@ -2001,3 +2001,13 @@ func.func @torch.aten.cuda$canonicalize(%arg0: !torch.tensor) -> !torch.tensor {
|
|||
%0 = torch.aten.cuda %arg0 : !torch.tensor -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.device.with_index$canonicalize
|
||||
// CHECK-NEXT: %[[VAL:.*]] = torch.constant.device "cuda:0"
|
||||
// CHECK-NEXT: return %[[VAL]] : !torch.Device
|
||||
func.func @torch.aten.device.with_index$canonicalize() -> !torch.Device {
|
||||
%str = torch.constant.str "cuda"
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.device.with_index %str, %int0 : !torch.str, !torch.int -> !torch.Device
|
||||
return %0 : !torch.Device
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue