[Torch] Fix torch.constant.int operation parsing (#3476)

Due to the custom operation parser, the print and parser were expecting
two different forms.

One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
pull/3512/head
Christopher McGirr 2024-06-28 16:06:52 +02:00 committed by GitHub
parent 23e3c0b5d2
commit 7e6d76e997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View File

@ -2882,11 +2882,11 @@ void ConstantDeviceOp::getAsmResultNames(
ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
Builder builder(result.getContext());
result.addTypes(builder.getType<Torch::IntType>());
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
int64_t value;
if (parser.parseInteger(value))
return failure();
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
result.addAttribute("value", builder.getI64IntegerAttr(value));
return success();
}

View File

@ -93,6 +93,9 @@ func.func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
// CHECK: %int-3 = torch.constant.int -3
%int-3 = torch.constant.int -3
// CHECK: %int5 = torch.constant.int 5 {test = "value"}
%int5 = torch.constant.int 5 {test = "value"}
// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
%float1.000000e00 = torch.constant.float 1.000000e+00
// CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00