Skip to content

Commit d8ea781

Browse files
authored
Merge pull request #423 from s22s/fix/357
Fixed `toLayer` on arbitrary RasterFrame.
2 parents 47b42a9 + c7bf0fb commit d8ea781

File tree

11 files changed

+89
-193
lines changed

11 files changed

+89
-193
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ object TileRasterizerAggregate {
100100

101101
def apply(tlm: TileLayerMetadata[_], sampler: ResampleMethod): ProjectedRasterDefinition = {
102102
// Try to determine the actual dimensions of our data coverage
103-
val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent) // <--- Do we have the math right here?
104-
val cols = actualSize.width
105-
val rows = actualSize.height
103+
val TileDimensions(cols, rows) = tlm.totalDimensions
106104
new ProjectedRasterDefinition(cols, rows, tlm.cellType, tlm.crs, tlm.extent, sampler)
107105
}
108106
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/ExtractBits.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
package org.locationtech.rasterframes.expressions.transformers
2323

2424
import geotrellis.raster.Tile
25-
import org.apache.spark.sql.{Column, TypedColumn}
25+
import org.apache.spark.sql.Column
2626
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2727
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2828
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback

core/src/main/scala/org/locationtech/rasterframes/extensions/Implicits.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.SparkConf
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.sql._
3333
import org.apache.spark.sql.types.{MetadataBuilder, Metadata => SMetadata}
34+
import org.locationtech.rasterframes.model.TileDimensions
3435
import spray.json.JsonFormat
3536

3637
import scala.reflect.runtime.universe._
@@ -79,6 +80,15 @@ trait Implicits {
7980
private[rasterframes]
8081
implicit class WithMetadataBuilderMethods(val self: MetadataBuilder)
8182
extends MetadataBuilderMethods
83+
84+
private[rasterframes]
85+
implicit class TLMHasTotalCells(tlm: TileLayerMetadata[_]) {
86+
// TODO: With upgrade to GT 3.1, replace this with the more general `Dimensions[Long]`
87+
def totalDimensions: TileDimensions = {
88+
val gb = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent)
89+
TileDimensions(gb.width, gb.height)
90+
}
91+
}
8292
}
8393

8494
object Implicits extends Implicits

core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala renamed to core/src/main/scala/org/locationtech/rasterframes/extensions/LayerSpatialColumnMethods.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import org.locationtech.rasterframes.encoders.serialized_literal
4141
*
4242
* @since 12/15/17
4343
*/
44-
trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns {
44+
trait LayerSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with StandardColumns {
4545
import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods}
4646
import org.locationtech.geomesa.spark.jts._
4747

@@ -112,7 +112,7 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
112112
*/
113113
def withCenterLatLng(colName: String = "center"): RasterFrameLayer = {
114114
val key2Center = sparkUdf(keyCol2LatLng)
115-
self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(RFSpatialColumnMethods.LngLatStructType)).certify
115+
self.withColumn(colName, key2Center(self.spatialKeyColumn).cast(LayerSpatialColumnMethods.LngLatStructType)).certify
116116
}
117117

118118
/**
@@ -130,6 +130,6 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
130130
}
131131
}
132132

133-
object RFSpatialColumnMethods {
133+
object LayerSpatialColumnMethods {
134134
private[rasterframes] val LngLatStructType = StructType(Seq(StructField("longitude", DoubleType), StructField("latitude", DoubleType)))
135135
}

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameLayerMethods.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ import scala.reflect.runtime.universe._
5252
* @since 7/18/17
5353
*/
5454
trait RasterFrameLayerMethods extends MethodExtensions[RasterFrameLayer]
55-
with RFSpatialColumnMethods with MetadataKeys {
55+
with LayerSpatialColumnMethods with MetadataKeys {
5656
import Implicits.{WithDataFrameMethods, WithRasterFrameLayerMethods}
5757

5858
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.locationtech.rasterframes.extensions
2323
import org.apache.spark.sql._
2424
import org.apache.spark.sql.functions._
2525
import org.locationtech.rasterframes._
26+
import org.locationtech.rasterframes.encoders.serialized_literal
2627
import org.locationtech.rasterframes.expressions.SpatialRelation
2728
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
2829
import org.locationtech.rasterframes.functions.reproject_and_merge
@@ -89,9 +90,12 @@ object RasterJoin {
8990
// After the aggregation we take all the tiles we've collected and resample + merge
9091
// into LHS extent/CRS.
9192
// Use a representative tile from the left for the tile dimensions
92-
val leftTile = left.tileColumns.headOption.getOrElse(throw new IllegalArgumentException("Need at least one target tile on LHS"))
93+
val destDims = left.tileColumns.headOption
94+
.map(t => rf_dimensions(unresolved(t)))
95+
.getOrElse(serialized_literal(NOMINAL_TILE_DIMS))
96+
9397
val reprojCols = rightAggTiles.map(t => reproject_and_merge(
94-
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), rf_dimensions(unresolved(leftTile))
98+
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims
9599
) as t.columnName)
96100

97101
val finalCols = leftAggCols.map(unresolved) ++ reprojCols ++ rightAggOther.map(unresolved)

core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import org.apache.spark.sql._
2626
import org.apache.spark.sql.functions.broadcast
2727
import org.locationtech.rasterframes._
2828
import org.locationtech.rasterframes.util._
29+
30+
/** Algorithm for projecting an arbitrary RasterFrame into a layer with consistent CRS and gridding. */
2931
object ReprojectToLayer {
3032
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = {
3133
// create a destination dataframe with crs and extend columns
@@ -42,8 +44,9 @@ object ReprojectToLayer {
4244
e = tlm.mapTransform(sk)
4345
} yield (sk, e, crs)
4446

47+
// Create effectively a target RasterFrame, but with no tiles.
4548
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
46-
dest.show(false)
49+
4750
val joined = RasterJoin(broadcast(dest), df)
4851

4952
joined.asLayer(SPATIAL_KEY_COLUMN, tlm)

core/src/test/scala/examples/CreatingRasterFrames.scala

Lines changed: 0 additions & 92 deletions
This file was deleted.

core/src/test/scala/examples/MeanValue.scala

Lines changed: 0 additions & 50 deletions
This file was deleted.

core/src/test/scala/org/locationtech/rasterframes/RasterFrameSpec.scala renamed to core/src/test/scala/org/locationtech/rasterframes/RasterLayerSpec.scala

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,21 @@
2323

2424
package org.locationtech.rasterframes
2525

26+
import java.net.URI
2627
import java.sql.Timestamp
2728
import java.time.ZonedDateTime
2829

29-
import org.locationtech.rasterframes.util._
30-
import geotrellis.proj4.LatLng
31-
import geotrellis.raster.render.{ColorMap, ColorRamp}
32-
import geotrellis.raster.{ProjectedRaster, Tile, TileFeature, TileLayout, UByteCellType}
30+
import geotrellis.proj4.{CRS, LatLng}
31+
import geotrellis.raster.{MultibandTile, ProjectedRaster, Raster, Tile, TileFeature, TileLayout, UByteCellType, UByteConstantNoDataCellType}
3332
import geotrellis.spark._
3433
import geotrellis.spark.tiling._
3534
import geotrellis.vector.{Extent, ProjectedExtent}
3635
import org.apache.spark.sql.functions._
37-
import org.apache.spark.sql.{SQLContext, SparkSession}
36+
import org.apache.spark.sql.{Encoders, SQLContext, SparkSession}
3837
import org.locationtech.rasterframes.model.TileDimensions
38+
import org.locationtech.rasterframes.ref.RasterSource
39+
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
40+
import org.locationtech.rasterframes.util._
3941

4042
import scala.util.control.NonFatal
4143

@@ -44,7 +46,7 @@ import scala.util.control.NonFatal
4446
*
4547
* @since 7/10/17
4648
*/
47-
class RasterFrameSpec extends TestEnvironment with MetadataKeys
49+
class RasterLayerSpec extends TestEnvironment with MetadataKeys
4850
with TestData {
4951
import TestData.randomTile
5052
import spark.implicits._
@@ -232,17 +234,40 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
232234
assert(bounds._2 === SpaceTimeKey(3, 1, now))
233235
}
234236

235-
def basicallySame(expected: Extent, computed: Extent): Unit = {
236-
val components = Seq(
237-
(expected.xmin, computed.xmin),
238-
(expected.ymin, computed.ymin),
239-
(expected.xmax, computed.xmax),
240-
(expected.ymax, computed.ymax)
241-
)
242-
forEvery(components)(c
243-
assert(c._1 === c._2 +- 0.000001)
237+
it("should create layer from arbitrary RasterFrame") {
238+
val src = RasterSource(URI.create("https://raw.githubusercontent.com/locationtech/rasterframes/develop/core/src/test/resources/LC08_RGB_Norfolk_COG.tiff"))
239+
val srcCrs = src.crs
240+
241+
def project(r: Raster[MultibandTile]): Seq[ProjectedRasterTile] =
242+
r.tile.bands.map(b => ProjectedRasterTile(b, r.extent, srcCrs))
243+
244+
val prtEnc = ProjectedRasterTile.prtEncoder
245+
implicit val enc = Encoders.tuple(prtEnc, prtEnc, prtEnc)
246+
247+
val rasters = src.readAll(bands = Seq(0, 1, 2)).map(project).map(p => (p(0), p(1), p(2)))
248+
249+
val df = rasters.toDF("red", "green", "blue")
250+
251+
val crs = CRS.fromString("+proj=utm +zone=18 +datum=WGS84 +units=m +no_defs")
252+
253+
val extent = Extent(364455.0, 4080315.0, 395295.0, 4109985.0)
254+
val layout = LayoutDefinition(extent, TileLayout(2, 2, 32, 32))
255+
256+
val tlm = new TileLayerMetadata[SpatialKey](
257+
UByteConstantNoDataCellType,
258+
layout,
259+
extent,
260+
crs,
261+
KeyBounds(SpatialKey(0, 0), SpatialKey(1, 1))
244262
)
245-
}
263+
val layer = df.toLayer(tlm)
264+
265+
val TileDimensions(cols, rows) = tlm.totalDimensions
266+
val prt = layer.toMultibandRaster(Seq($"red", $"green", $"blue"), cols, rows)
267+
prt.tile.dimensions should be((cols, rows))
268+
prt.crs should be(crs)
269+
prt.extent should be(extent)
270+
}
246271

247272
it("shouldn't clip already clipped extents") {
248273
val rf = TestData.randomSpatialTileLayerRDD(1024, 1024, 8, 8).toLayer
@@ -258,27 +283,8 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
258283
basicallySame(expected2, computed2)
259284
}
260285

261-
def Greyscale(stops: Int): ColorRamp = {
262-
val colors = (0 to stops)
263-
.map(i {
264-
val c = java.awt.Color.HSBtoRGB(0f, 0f, i / stops.toFloat)
265-
(c << 8) | 0xFF // Add alpha channel.
266-
})
267-
ColorRamp(colors)
268-
}
269-
270-
def render(tile: Tile, tag: String): Unit = {
271-
if(false && !isCI) {
272-
val colors = ColorMap.fromQuantileBreaks(tile.histogram, Greyscale(128))
273-
val path = s"target/${getClass.getSimpleName}_$tag.png"
274-
logger.info(s"Writing '$path'")
275-
tile.color(colors).renderPng().write(path)
276-
}
277-
}
278-
279286
it("should rasterize with a spatiotemporal key") {
280287
val rf = TestData.randomSpatioTemporalTileLayerRDD(20, 20, 2, 2).toLayer
281-
282288
noException shouldBe thrownBy {
283289
rf.toRaster($"tile", 128, 128)
284290
}
@@ -291,7 +297,6 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
291297
val joinTypes = Seq("inner", "outer", "fullouter", "left_outer", "right_outer", "leftsemi")
292298
forEvery(joinTypes) { jt
293299
val joined = rf1.spatialJoin(rf2, jt)
294-
//println(joined.schema.json)
295300
assert(joined.tileLayerMetadata.isRight)
296301
}
297302
}

0 commit comments

Comments
 (0)