Add Shape inference for CopyOp for lazy tensor core backend (#2006)

- Add Shape inference for CopyOp for LTC backend
pull/2013/head
rahuls-cerebras 2023-04-12 19:07:03 +05:30 committed by GitHub
parent 224ee27610
commit c2c96c430a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 7 deletions

View File

@ -895,24 +895,17 @@ LTC_XFAIL_SET = {
"VarMeanCorrectionModule_basic",
"VarMeanCorrectionNoneModule_basic",
"PrimsConvertElementTypeModule_basic",
"CopyModule_basic",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
"CopyWithDifferentSizesModule_basic",
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliPModule_basic",
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",

View File

@ -56,5 +56,12 @@ std::vector<torch::lazy::Shape> compute_shape_bucketize(
return {Shape(dtype, self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_copy(
const at::Tensor& self,
const at::Tensor& src,
bool non_blocking) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
} // namespace lazy
} // namespace torch