diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index d842ea77b..06bbb1ac5 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -95,6 +95,16 @@ struct OpBinder { return success(); } + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { + if (idx >= op->getNumResults()) + return failure(); + auto t = toValidTensorType(op->getResult(idx).getType()); + if (!t) + return failure(); + typeIdx = t; + return success(); + } + // Attribute accessors. ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, bool defaultValue = false) {