import cairo
import math

class SpriteAnimator:
    """Steps through the frames of a sprite sheet stored in a PNG file.

    The sheet is read as a single row of equally wide frames: frame ``i``
    starts at x offset ``frameWidth * i`` and every frame is as tall as the
    whole image.
    """

    def __init__(self, filename, frames=0):
        """Load the sprite sheet and work out its frame layout.

        filename -- path of the PNG file holding the sprite sheet.
        frames   -- number of frames in the sheet; 0 (the default) means
                    the frames are assumed to be square and the count is
                    derived from the image dimensions.
        """
        self.filename = filename
        self.currentFrame = 0
        self.baseImage = cairo.ImageSurface.create_from_png(filename)
        self.frameHeight = self.baseImage.get_height()
        self.frames, self.frameWidth = self._frameLayout(
            self.baseImage.get_width(), self.frameHeight, frames)

    @staticmethod
    def _frameLayout(imageWidth, frameHeight, frames):
        """Return ``(frame count, frame width)`` for a sheet of the given size."""
        if frames == 0:
            # Floor division keeps the count an integer on Python 3 as well
            # as Python 2.  A sheet narrower than it is tall would otherwise
            # yield zero frames, so treat it as one frame.
            frames = max(1, imageWidth // frameHeight)
        return frames, imageWidth // frames

    def reset(self):
        """Jump back to the first frame."""
        self.currentFrame = 0

    def nextFrame(self):
        """Advance one frame, wrapping from the last frame to the first."""
        self.currentFrame = (self.currentFrame + 1) % self.frames

    def prevFrame(self):
        """Step back one frame, wrapping from the first frame to the last."""
        self.currentFrame = (self.currentFrame - 1) % self.frames

    def drawFrame(self, ctx, x, y):
        """Draw the current frame on ``ctx`` with its top-left corner at (x, y)."""
        drawSubImage(ctx, self.baseImage, x, y,
                     self.frameWidth, self.frameHeight,
                     self.frameWidth * self.currentFrame, 0)


def drawSubImage(ctx, image, destX, destY, destWidth, destHeight,
                 sourceX, sourceY):
    """Draw a rectangular piece of ``image`` onto ``ctx``.

    ctx                   -- drawing context to paint on.
    image                 -- surface to copy from.
    destX, destY          -- top-left corner of the target rectangle.
    destWidth, destHeight -- size of the target rectangle; the same size is
                             taken from the source, so nothing is scaled.
    sourceX, sourceY      -- top-left corner of the piece inside ``image``.

    The context state is saved on entry and restored on exit, including
    when one of the drawing calls raises.
    """
    ctx.save()
    try:
        # Limit painting to the target rectangle.
        ctx.rectangle(destX, destY, destWidth, destHeight)
        ctx.clip()
        ctx.new_path()
        # Shift the image so that (sourceX, sourceY) lands on (destX, destY).
        ctx.set_source_surface(image, destX - sourceX, destY - sourceY)
        ctx.paint()
    finally:
        # Always undo save() so the caller's state stack stays balanced.
        ctx.restore()

def drawImage(ctx, image, destX, destY):
    """Set ``image`` as the source of ``ctx`` at (destX, destY) and paint it.

    No state is saved or restored here, so ``image`` remains the source of
    ``ctx`` after the call.
    """
    origin = (destX, destY)
    ctx.set_source_surface(image, *origin)
    ctx.paint()

def _tileSpans(start, end, tileSize):
    """Return ``(position, length)`` pairs covering ``[start, end)``.

    Spans are ``tileSize`` long except the last one, which is shortened so
    that it never reaches past ``end``.  An empty list is returned when
    ``end`` is not greater than ``start``.
    """
    count = int(math.ceil(float(end - start) / tileSize))
    spans = []
    position = start
    for _ in range(count):
        length = min(tileSize, end - position)
        if length <= 0:
            # Rounding in the count can leave a degenerate trailing span.
            break
        spans.append((position, length))
        position += length
    return spans


def drawTiles(ctx, image, destX, destY, destWidth, destHeight, borderWidth=0):
    """Fill a rectangle on ``ctx`` by tiling ``image``, with an optional frame.

    ctx                   -- drawing context to paint on.
    image                 -- surface providing the tile (and frame) artwork.
    destX, destY          -- top-left corner of the rectangle to fill.
    destWidth, destHeight -- size of the rectangle to fill.
    borderWidth           -- thickness of the frame taken from the outer
                             edge of ``image``.  The four corners are drawn
                             once, the edge strips are repeated along the
                             sides, and the remaining centre of ``image`` is
                             tiled over the interior.  With 0 the whole image
                             is used as the tile.

    Tiles that would stick out past the right or bottom of the interior are
    cropped.  ``borderWidth`` is expected to be less than half of each image
    dimension; if the centre tile has zero width or height a
    ZeroDivisionError is raised.
    """
    imageWidth = image.get_width()
    imageHeight = image.get_height()
    baseTileWidth = imageWidth - 2 * borderWidth
    baseTileHeight = imageHeight - 2 * borderWidth

    # Absolute positions of the right and bottom frame strips.  These must
    # include destX/destY so the frame follows a rectangle that is not
    # anchored at the origin.
    rightX = destX + destWidth - borderWidth
    bottomY = destY + destHeight - borderWidth

    # Interior tiling grid, computed once and shared by edges and centre.
    columns = _tileSpans(destX + borderWidth, rightX, baseTileWidth)
    rows = _tileSpans(destY + borderWidth, bottomY, baseTileHeight)

    if borderWidth > 0:
        sourceRight = imageWidth - borderWidth
        sourceBottom = imageHeight - borderWidth

        # Draw the corners
        drawSubImage(ctx, image, destX, destY, borderWidth, borderWidth,
                     0, 0)
        drawSubImage(ctx, image, rightX, destY, borderWidth, borderWidth,
                     sourceRight, 0)
        drawSubImage(ctx, image, destX, bottomY, borderWidth, borderWidth,
                     0, sourceBottom)
        drawSubImage(ctx, image, rightX, bottomY, borderWidth, borderWidth,
                     sourceRight, sourceBottom)

        # Draw the top and bottom edges
        for x, tileWidth in columns:
            drawSubImage(ctx, image, x, destY, tileWidth, borderWidth,
                         borderWidth, 0)
            drawSubImage(ctx, image, x, bottomY, tileWidth, borderWidth,
                         borderWidth, sourceBottom)

        # Draw the left and right edges
        for y, tileHeight in rows:
            drawSubImage(ctx, image, destX, y, borderWidth, tileHeight,
                         0, borderWidth)
            drawSubImage(ctx, image, rightX, y, borderWidth, tileHeight,
                         sourceRight, borderWidth)

    # Draw the interior
    for y, tileHeight in rows:
        for x, tileWidth in columns:
            drawSubImage(ctx, image, x, y, tileWidth, tileHeight,
                         borderWidth, borderWidth)
