[Torch] Fold aten rounding ops on splat constants.#4359
[Torch] Fold aten rounding ops on splat constants.#4359
Conversation
This commit teaches the folding methods of `AtenFloor`, `AtenCeil`, `AtenRound`, and `AtenTruc` to constant-fold roundings when the operand is a splat `DenseElementsAttr`.
58550c3 to
175d2a2
Compare
|
@sahas3 @zjgarvey @vivekkhandelwal1 can you please take a look at this? |
sahas3
left a comment
There was a problem hiding this comment.
Nice change, looks good to me for most part.
| } | ||
|
|
||
| // Common helper for splat-only rounding-based folders. | ||
| static OpFoldResult foldSplatRounding(ValueTensorType resultType, |
There was a problem hiding this comment.
I think naming it foldFloatSplatWithRounding will be more appropriate since it only handles float data.
| return {}; | ||
|
|
||
| auto outShaped = resultType.toBuiltinTensor(); | ||
| if (!outShaped.hasStaticShape()) |
There was a problem hiding this comment.
why is static shape a requirement?
| @export | ||
| @annotate_args([None]) | ||
| def forward(self): | ||
| return torch.ops.aten.round(self.const) |
There was a problem hiding this comment.
I am guessing the other ops are already covered in e2e tests?
| // CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>) | ||
| // CHECK: return %[[C]] | ||
| func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> { | ||
| %cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>) |
There was a problem hiding this comment.
Can you add a +ve value too for completeness?
| // CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>) | ||
| // CHECK: return %[[C]] | ||
| func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> { | ||
| %cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>) |
There was a problem hiding this comment.
Similarly adding a negative value will give full coverage
| // NaNs and infs are dealt with consistently with torch, so side-effects | ||
| // can be discarded. |
There was a problem hiding this comment.
Clarification: Can you elaborate on what you meant by "NaNs and infs are dealt with consistently with torch" ?
This commit teaches the folding methods of
AtenFloor,AtenCeil,AtenRound, andAtenTructo constant-fold roundings when the operand is a splatDenseElementsAttr.