Added tensorResultTypeAtIndex to Patterns.h

Need this for LayerNorm
pull/2722/head
Andreas Falkenberg 2023-12-29 11:21:55 -08:00 committed by Vivek Khandelwal
parent 9adad9bc40
commit 80bd093d56
1 changed files with 10 additions and 0 deletions

View File

@ -95,6 +95,16 @@ struct OpBinder {
return success();
}
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) {
if (idx >= op->getNumResults())
return failure();
auto t = toValidTensorType(op->getResult(idx).getType());
if (!t)
return failure();
typeIdx = t;
return success();
}
// Attribute accessors.
ParseResult s64BoolAttr(bool &value, StringRef nameSuffix,
bool defaultValue = false) {